| @ -0,0 +1,168 @@ | |||||
| # Created by .ignore support plugin (hsz.mobi) | |||||
| ### JetBrains template | |||||
| # Covers JetBrains IDEs: IntelliJ, RubyMine, PhpStorm, AppCode, PyCharm, CLion, Android Studio and Webstorm | |||||
| # Reference: https://intellij-support.jetbrains.com/hc/en-us/articles/206544839 | |||||
| settings/local.py | |||||
| logs/*.log | |||||
| # User-specific stuff: | |||||
| .idea/ | |||||
| # Sensitive or high-churn files: | |||||
| .idea/**/dataSources/ | |||||
| .idea/**/dataSources.ids | |||||
| .idea/**/dataSources.xml | |||||
| .idea/**/dataSources.local.xml | |||||
| .idea/**/sqlDataSources.xml | |||||
| .idea/**/dynamic.xml | |||||
| .idea/**/uiDesigner.xml | |||||
| # Gradle: | |||||
| .idea/**/gradle.xml | |||||
| .idea/**/libraries | |||||
| # CMake | |||||
| cmake-build-debug/ | |||||
| # Mongo Explorer plugin: | |||||
| .idea/**/mongoSettings.xml | |||||
| ## File-based project format: | |||||
| *.iws | |||||
| ## Plugin-specific files: | |||||
| # IntelliJ | |||||
| out/ | |||||
| # mpeltonen/sbt-idea plugin | |||||
| .idea_modules/ | |||||
| # JIRA plugin | |||||
| atlassian-ide-plugin.xml | |||||
| # Cursive Clojure plugin | |||||
| .idea/replstate.xml | |||||
| # Crashlytics plugin (for Android Studio and IntelliJ) | |||||
| com_crashlytics_export_strings.xml | |||||
| crashlytics.properties | |||||
| crashlytics-build.properties | |||||
| fabric.properties | |||||
| ### Python template | |||||
| # Byte-compiled / optimized / DLL files | |||||
| __pycache__/ | |||||
| *.py[cod] | |||||
| *$py.class | |||||
| # C extensions | |||||
| *.so | |||||
| # Distribution / packaging | |||||
| .Python | |||||
| build/ | |||||
| develop-eggs/ | |||||
| dist/ | |||||
| downloads/ | |||||
| eggs/ | |||||
| .eggs/ | |||||
| lib/ | |||||
| lib64/ | |||||
| parts/ | |||||
| sdist/ | |||||
| var/ | |||||
| wheels/ | |||||
| *.egg-info/ | |||||
| .installed.cfg | |||||
| *.egg | |||||
| # PyInstaller | |||||
| # Usually these files are written by a python script from a template | |||||
| # before PyInstaller builds the exe, so as to inject date/other infos into it. | |||||
| *.manifest | |||||
| *.spec | |||||
| # Installer logs | |||||
| pip-log.txt | |||||
| pip-delete-this-directory.txt | |||||
| # Unit test / coverage reports | |||||
| htmlcov/ | |||||
| .tox/ | |||||
| .coverage | |||||
| .coverage.* | |||||
| .cache | |||||
| nosetests.xml | |||||
| coverage.xml | |||||
| *.cover | |||||
| .hypothesis/ | |||||
| # Translations | |||||
| *.mo | |||||
| *.pot | |||||
| # Django stuff: | |||||
| *.log | |||||
| local_settings.py | |||||
| # Flask stuff: | |||||
| instance/ | |||||
| .webassets-cache | |||||
| # Scrapy stuff: | |||||
| .scrapy | |||||
| # Sphinx documentation | |||||
| docs/_build/ | |||||
| # PyBuilder | |||||
| target/ | |||||
| # Jupyter Notebook | |||||
| .ipynb_checkpoints | |||||
| # pyenv | |||||
| .python-version | |||||
| # celery beat schedule file | |||||
| celerybeat-schedule | |||||
| # SageMath parsed files | |||||
| *.sage.py | |||||
| # Environments | |||||
| .env | |||||
| .venv | |||||
| env/ | |||||
| venv/ | |||||
| ENV/ | |||||
| # Spyder project settings | |||||
| .spyderproject | |||||
| .spyproject | |||||
| # Rope project settings | |||||
| .ropeproject | |||||
| # mkdocs documentation | |||||
| /site | |||||
| # mypy | |||||
| .mypy_cache/ | |||||
| /tests/load_tests/logs/* | |||||
| /tests/.pytest_cache/ | |||||
| ws_test.py | |||||
| /.vscode/ | |||||
| .s3_cache/ | |||||
| mlruns | |||||
| *.pyc | |||||
| *.swp | |||||
| *.pt | |||||
| *.bin | |||||
| .vscode/ | |||||
| runs/ | |||||
| jupyters/custom_* | |||||
| *logs/ | |||||
| @ -0,0 +1,31 @@ | |||||
| repos: | |||||
| - repo: https://github.com/pre-commit/pre-commit-hooks | |||||
| rev: v2.2.3 | |||||
| hooks: | |||||
| - id: check-docstring-first | |||||
| stages: | |||||
| - commit | |||||
| - push | |||||
| - id: check-merge-conflict | |||||
| stages: | |||||
| - push | |||||
| - id: double-quote-string-fixer | |||||
| stages: | |||||
| - commit | |||||
| - push | |||||
| - id: fix-encoding-pragma | |||||
| stages: | |||||
| - commit | |||||
| - push | |||||
| - id: flake8 | |||||
| args: ['--config=setup.cfg'] | |||||
| stages: | |||||
| - commit | |||||
| - push | |||||
| - repo: https://github.com/pre-commit/mirrors-autopep8 | |||||
| rev: v1.4.4 | |||||
| hooks: | |||||
| - id: autopep8 | |||||
| stages: | |||||
| - commit | |||||
| - push | |||||
| @ -0,0 +1,4 @@ | |||||
| -r requirements.txt | |||||
| pytest | |||||
| pytest-cov | |||||
| pre-commit | |||||
| @ -0,0 +1,8 @@ | |||||
| taming-transformers==0.0.1 | |||||
| more_itertools==8.10.0 | |||||
| transformers==4.10.2 | |||||
| youtokentome==1.0.6 | |||||
| einops==0.3.2 | |||||
| torch | |||||
| torchvision | |||||
| matplotlib | |||||
| @ -0,0 +1,15 @@ | |||||
| # -*- coding: utf-8 -*- | |||||
| from .vae import get_vae | |||||
| from .dalle import get_rudalle_model | |||||
| from .tokenizer import get_tokenizer | |||||
| from . import vae, dalle, tokenizer | |||||
# Explicit public API of the package: factory helpers plus the submodules.
__all__ = [
    'get_vae',
    'get_rudalle_model',
    'get_tokenizer',
    'vae',
    'dalle',
    'tokenizer',
]
| @ -0,0 +1,76 @@ | |||||
| # -*- coding: utf-8 -*- | |||||
| import os | |||||
| import torch | |||||
| from huggingface_hub import hf_hub_url, cached_download | |||||
| from .model import DalleModel | |||||
| from .fp16 import FP16Module | |||||
# Registry of known model configurations, keyed by public model name.
# Each entry carries the DalleModel hyper-parameters plus the Hugging Face
# hub location (repo_id/filename) of the pretrained checkpoint.
MODELS = {
    'Malevich': dict(
        description='◼️ Malevich is 1.3 billion params model from the family GPT3-like, '
                    'that uses Russian language and text+image multi-modality.',
        model_params=dict(
            num_layers=24,
            hidden_size=2048,
            num_attention_heads=16,
            embedding_dropout_prob=0.1,
            output_dropout_prob=0.1,
            attention_dropout_prob=0.1,
            image_tokens_per_dim=32,  # 32x32 grid -> 1024 image tokens
            text_seq_length=128,
            use_masks=True,
            cogview_sandwich_layernorm=True,
            cogview_pb_relax=True,
            # text vocab; the extra 128 ids are presumably reserved service/padding
            # tokens used by DalleModel.forward's zero-replacement trick — confirm
            vocab_size=16384+128,
            image_vocab_size=8192,
        ),
        repo_id='sberbank-ai/rudalle-Malevich',
        filename='pytorch_model.bin',
        full_description='',  # TODO
    ),
    'small': dict(
        description='',
        model_params=dict(
            num_layers=12,
            hidden_size=768,
            num_attention_heads=12,
            embedding_dropout_prob=0.1,
            output_dropout_prob=0.1,
            attention_dropout_prob=0.1,
            image_tokens_per_dim=32,
            text_seq_length=128,
            use_masks=True,
            cogview_sandwich_layernorm=True,
            cogview_pb_relax=True,
            vocab_size=16384+128,
            image_vocab_size=8192,
        ),
        # No pretrained checkpoint published for this config yet.
        repo_id='',
        filename='',
        full_description='',  # TODO
    ),
}
def get_rudalle_model(name, pretrained=True, fp16=False, device='cpu', cache_dir='/tmp/rudalle'):
    """Build a ruDALL-E model by configuration name.

    Args:
        name: key into ``MODELS`` (e.g. ``'Malevich'``).
        pretrained: if True, download the checkpoint from the Hugging Face hub
            (cached under ``cache_dir/<name>``) and load its weights.
        fp16: if True, wrap the model in ``FP16Module`` (half precision).
        device: torch device the model is moved to.
        cache_dir: root directory for downloaded checkpoints.

    Returns:
        The model (possibly ``FP16Module``-wrapped), in eval mode, on ``device``.

    Raises:
        ValueError: if ``name`` is not a known configuration.
    """
    # Validate with an explicit exception instead of `assert`: asserts are
    # stripped under `python -O` and give an uninformative error message.
    if name not in MODELS:
        raise ValueError('Unknown model name: {!r}; available: {}'.format(name, sorted(MODELS)))
    config = MODELS[name]
    model = DalleModel(device=device, fp16=fp16, **config['model_params'])
    if pretrained:
        cache_dir = os.path.join(cache_dir, name)
        config_file_url = hf_hub_url(repo_id=config['repo_id'], filename=config['filename'])
        # force_filename keeps the local name stable so torch.load can find it.
        cached_download(config_file_url, cache_dir=cache_dir, force_filename=config['filename'])
        checkpoint = torch.load(os.path.join(cache_dir, config['filename']), map_location='cpu')
        model.load_state_dict(checkpoint)
    if fp16:
        model = FP16Module(model)
    model.eval()
    model = model.to(device)
    if config['description'] and pretrained:
        print(config['description'])
    return model
| @ -0,0 +1,60 @@ | |||||
| # -*- coding: utf-8 -*- | |||||
| import torch | |||||
| from torch import nn | |||||
| from torch.autograd import Variable | |||||
| from torch.nn.parameter import Parameter | |||||
# Tensor classes used to detect the source precision of a value.
FLOAT_TYPES = (torch.FloatTensor, torch.cuda.FloatTensor)
HALF_TYPES = (torch.HalfTensor, torch.cuda.HalfTensor)


def conversion_helper(val, conversion):
    """Apply `conversion` to `val`, recursing through nested tuples/lists.

    Tuples stay tuples, lists stay lists; leaves are passed to `conversion`.
    """
    if isinstance(val, (tuple, list)):
        converted = [conversion_helper(item, conversion) for item in val]
        return tuple(converted) if isinstance(val, tuple) else converted
    return conversion(val)


def fp32_to_fp16(val):
    """Convert fp32 tensors inside (possibly nested) `val` to fp16."""
    def half_conversion(item):
        # Parameters/Variables are inspected via .data but cast as themselves.
        payload = item.data if isinstance(item, (Parameter, Variable)) else item
        return item.half() if isinstance(payload, FLOAT_TYPES) else item
    return conversion_helper(val, half_conversion)


def fp16_to_fp32(val):
    """Convert fp16 tensors inside (possibly nested) `val` to fp32."""
    def float_conversion(item):
        payload = item.data if isinstance(item, (Parameter, Variable)) else item
        return item.float() if isinstance(payload, HALF_TYPES) else item
    return conversion_helper(val, float_conversion)
class FP16Module(nn.Module):
    """Half-precision wrapper: stores `module` in fp16, casts incoming fp32
    tensors to fp16 and casts fp16 outputs back to fp32."""

    def __init__(self, module):
        super(FP16Module, self).__init__()
        # Register under the name 'module' so its parameters are tracked.
        self.add_module('module', module.half())

    def forward(self, *inputs, **kwargs):
        # fp32 -> fp16 on the way in, fp16 -> fp32 on the way out.
        # NOTE(review): kwargs are passed through uncast — presumably they never
        # carry fp32 tensors; confirm against callers.
        return fp16_to_fp32(self.module(*(fp32_to_fp16(inputs)), **kwargs))

    def state_dict(self, destination=None, prefix='', keep_vars=False):
        # Delegate so checkpoints don't carry an extra 'module.' prefix.
        return self.module.state_dict(destination, prefix, keep_vars)

    def load_state_dict(self, state_dict, strict=True):
        self.module.load_state_dict(state_dict, strict=strict)

    def get_param(self, item):
        # Mirrors DalleModel.get_param so callers don't care about wrapping.
        return self.module.get_param(item)
| @ -0,0 +1,48 @@ | |||||
| # -*- coding: utf-8 -*- | |||||
| import torch | |||||
| def _init_mask(text_tokens, image_tokens_per_dim): | |||||
| attn_size = text_tokens + image_tokens_per_dim**2 | |||||
| mask = torch.tril(torch.ones(attn_size, attn_size)) | |||||
| return mask | |||||
def get_row_mask(text_tokens=256, image_tokens_per_dim=32):
    """Causal mask where each image position only sees back one grid row."""
    mask = _init_mask(text_tokens, image_tokens_per_dim)
    window = image_tokens_per_dim + 1
    # For every image column `pos`, forbid attention from positions further
    # than one row (plus one token) ahead of it.
    for pos in range(text_tokens, mask.size(1)):
        mask[pos + window:, pos] = 0.0
    return mask
def get_col_mask(text_tokens=256, image_tokens_per_dim=32):
    """Causal mask where image positions attend down their own grid column."""
    mask = _init_mask(text_tokens, image_tokens_per_dim)
    gap = image_tokens_per_dim - 1
    # For every image column `pos`, zero out the `gap` off-column positions
    # between consecutive same-column cells.
    for pos in range(text_tokens, mask.size(1)):
        for offset in range(1, mask.size(0), gap + 1):
            mask[pos + offset: pos + offset + gap, pos] = 0.0
    return mask
def get_conv_mask(text_tokens=256, image_tokens_per_dim=32, kernel=11):
    """Causal mask where image positions attend to a kernel x kernel
    neighbourhood of the image grid (wrapping at the borders — see the modulo
    arithmetic below)."""
    mask = _init_mask(text_tokens, image_tokens_per_dim)
    shift = kernel // 2
    for pos in range(text_tokens, mask.size(1)):
        # Start from "nothing after pos may attend to pos", then re-enable
        # the positions inside the kernel footprint.
        mask[pos+1:, pos] = 0.0
        # NOTE(review): `img` is written to but never read — dead scratch
        # (probably a leftover visualisation aid); kept for parity.
        img = torch.zeros(image_tokens_per_dim, image_tokens_per_dim)
        pixel_id = pos - text_tokens
        row = pixel_id // image_tokens_per_dim
        col = pixel_id % image_tokens_per_dim
        for r in range(-shift, shift+1):
            for c in range(-shift, shift+1):
                # Kernel wraps around the grid edges (torus-style).
                # NOTE(review): wrapping may be unintended at image borders — confirm.
                c_abs = (c + col) % image_tokens_per_dim
                r_abs = (r + row) % image_tokens_per_dim
                img[r_abs, c_abs] = 0.2
                cell_id = r_abs * image_tokens_per_dim + c_abs
                if text_tokens + cell_id > pos:
                    # Let future positions inside the kernel attend back to `pos`.
                    mask[text_tokens + cell_id, pos] = 1.0
        img[row, col] = 1.0
    return mask
| @ -0,0 +1,171 @@ | |||||
| # -*- coding: utf-8 -*- | |||||
| import torch | |||||
| import torch.nn.functional as F | |||||
| from einops import rearrange | |||||
| from .utils import exists, is_empty, init_method_normal | |||||
| from .image_attention import get_conv_mask, get_row_mask, get_col_mask | |||||
| from .transformer import DalleTransformer | |||||
class DalleModel(torch.nn.Module):
    """GPT-like autoregressive transformer over concatenated text+image tokens.

    The sequence layout is ``[text_seq_length text tokens | image_tokens_per_dim**2
    image tokens]``. Text and image tokens occupy disjoint ranges of a single
    shared output vocabulary of size ``vocab_size + image_vocab_size``.
    """

    def __init__(self,
                 device,
                 num_layers,
                 vocab_size,
                 hidden_size,
                 num_attention_heads,
                 embedding_dropout_prob,
                 attention_dropout_prob,
                 output_dropout_prob,
                 text_seq_length=128,
                 image_tokens_per_dim=32,
                 image_vocab_size=16384,
                 loss_img_weight=7,
                 fp16=False,
                 use_masks=True,
                 cogview_sandwich_layernorm=False,
                 cogview_pb_relax=False):
        super(DalleModel, self).__init__()
        self.device = device
        self.fp16 = fp16
        self.image_tokens_per_dim = image_tokens_per_dim
        # Image is a square grid of tokens.
        self.image_seq_length = image_tokens_per_dim ** 2
        self.text_seq_length = text_seq_length
        self.total_seq_length = self.text_seq_length + self.image_seq_length
        # Text and image ids share a single output softmax.
        self.total_vocab_size = vocab_size + image_vocab_size
        self.vocab_size = vocab_size
        # Relative weight of the image loss term vs the text loss (see forward).
        self.loss_img_weight = loss_img_weight
        # TODO "to"  (original note — masks are moved to device inside prepare_image_masks)
        mask_map = self.prepare_image_masks(num_layers, text_seq_length, image_tokens_per_dim)
        if use_masks:
            self._mask_map = mask_map
        else:
            # Empty list disables the per-layer sparse masks in the transformer.
            self._mask_map = []
        init_method = init_method_normal(std=0.02)
        self.text_embeddings = torch.nn.Embedding(vocab_size, hidden_size)
        self.image_embeddings = torch.nn.Embedding(image_vocab_size, hidden_size)
        # Position embedding (serial). +1 because forward prepends one token.
        self.text_pos_embeddings = torch.nn.Embedding(text_seq_length + 1, hidden_size)
        self.image_row_embeddings = torch.nn.Embedding(image_tokens_per_dim, hidden_size)
        self.image_col_embeddings = torch.nn.Embedding(image_tokens_per_dim, hidden_size)
        init_method(self.text_pos_embeddings.weight)
        init_method(self.image_row_embeddings.weight)
        init_method(self.image_col_embeddings.weight)
        # Final LayerNorm + projection to the joint text+image vocabulary.
        self.to_logits = torch.nn.Sequential(
            torch.nn.LayerNorm(hidden_size),
            torch.nn.Linear(hidden_size, self.total_vocab_size),
        )
        # Embeddings dropout
        # NOTE(review): embedding_dropout is created but never applied in
        # forward below — confirm whether that is intentional.
        self.embedding_dropout = torch.nn.Dropout(embedding_dropout_prob)
        # Transformer
        self.transformer = DalleTransformer(
            num_layers,
            hidden_size,
            num_attention_heads,
            attention_dropout_prob,
            output_dropout_prob,
            cogview_sandwich_layernorm=cogview_sandwich_layernorm,
            cogview_pb_relax=cogview_pb_relax,
        )
        # Hand the per-layer sparse masks to the transformer (instance attr
        # shadows the DalleTransformer class-level default).
        self.transformer._mask_map = self._mask_map

    def get_param(self, item):
        """Return the attribute named `item` (used e.g. by FP16Module)."""
        return getattr(self, item)

    def prepare_image_masks(self, num_layers, text_seq_length, image_tokens_per_dim):
        """Build per-layer sparse attention masks (CogView-style row/col/conv).

        Layer i gets: the column mask when (i - 1) % 4 == 0, the conv mask on
        the last layer, and the row mask everywhere else.
        """
        row_mask = get_row_mask(text_seq_length, image_tokens_per_dim).to(self.device)
        col_mask = get_col_mask(text_seq_length, image_tokens_per_dim).to(self.device)
        conv_mask = get_conv_mask(text_seq_length, image_tokens_per_dim).to(self.device)
        if self.fp16:
            row_mask = row_mask.half()
            col_mask = col_mask.half()
            conv_mask = conv_mask.half()
        # Buffers so the masks follow the module across .to()/.half()/state_dict.
        self.register_buffer('row_mask', row_mask)
        self.register_buffer('col_mask', col_mask)
        self.register_buffer('conv_mask', conv_mask)
        mask_map = []
        for i in range(num_layers):
            if ((i - 1) % 4 == 0):
                mask_map.append(col_mask)
            elif i != num_layers - 1:
                mask_map.append(row_mask)
            else:
                mask_map.append(conv_mask)
        return mask_map

    def get_image_pos_embeddings(self, image_input_ids, past_length=0):
        """Positional embedding for image tokens: sum of row and column embeddings.

        `past_length` offsets the absolute positions during cached generation.
        """
        input_shape = image_input_ids.size()
        row_ids = torch.arange(past_length, input_shape[-1] + past_length,
                               dtype=torch.long, device=self.device) // self.image_tokens_per_dim
        row_ids = row_ids.unsqueeze(0).view(-1, input_shape[-1])
        col_ids = torch.arange(past_length, input_shape[-1] + past_length,
                               dtype=torch.long, device=self.device) % self.image_tokens_per_dim
        col_ids = col_ids.unsqueeze(0).view(-1, input_shape[-1])
        return self.image_row_embeddings(row_ids) + self.image_col_embeddings(col_ids)

    def forward(
        self,
        input_ids,
        attention_mask,
        return_loss=False,
        has_cache=False,
        use_cache=False,
    ):
        """Run the model.

        Args:
            input_ids: [b, s] token ids; first text_seq_length are text,
                the rest (possibly empty) are image tokens.
            attention_mask: [b, heads?, s, s] causal mask, cropped below to the
                actual embedded length.
            return_loss: when True, return (loss, {'text': ..., 'image': ...});
                otherwise return (logits, has_cache flag).
            has_cache / use_cache: incremental-decoding flags, threaded through
                the transformer layers.
        """
        text = input_ids[:, :self.text_seq_length]
        # Replace padding zeros with unique per-position ids taken from the
        # tail of the text vocabulary — presumably reserved padding tokens
        # (matches the vocab_size=16384+128 configs); confirm.
        text_range = torch.arange(self.text_seq_length)
        text_range += (self.vocab_size - self.text_seq_length)
        text_range = text_range.to(self.device)
        text = torch.where(text == 0, text_range, text)
        # some hardcode :)  — prepend fixed token id 2 (presumably BOS; confirm)
        text = F.pad(text, (1, 0), value=2)
        text_embeddings = self.text_embeddings(text) + \
            self.text_pos_embeddings(torch.arange(text.shape[1], device=self.device))
        image_input_ids = input_ids[:, self.text_seq_length:]
        if exists(image_input_ids) and not is_empty(image_input_ids):
            image_embeddings = self.image_embeddings(image_input_ids) + \
                self.get_image_pos_embeddings(image_input_ids, past_length=0)
            embeddings = torch.cat((text_embeddings, image_embeddings), dim=1)
        else:
            embeddings = text_embeddings
        # some hardcode :)  — the BOS pad above grew the sequence by 1; trim the
        # trailing position when the full text+image sequence is present.
        if embeddings.shape[1] > self.total_seq_length:
            embeddings = embeddings[:, :-1]
        # Forward value is unchanged (x*a + x.detach()*(1-a) == x) but the
        # gradient through the embeddings is scaled by alpha.
        alpha = 0.1
        embeddings = embeddings * alpha + embeddings.detach() * (1-alpha)
        attention_mask = attention_mask[:, :, :embeddings.shape[1], :embeddings.shape[1]]
        transformer_output, present_has_cache = self.transformer(
            embeddings, attention_mask, has_cache=has_cache, use_cache=use_cache)
        logits = self.to_logits(transformer_output)
        if return_loss is False:
            return logits, present_has_cache
        # Shift text targets by one (standard next-token prediction); image
        # targets align with the un-shifted image positions.
        labels = torch.cat((text[:, 1:], image_input_ids), dim=1).contiguous().long()
        logits = rearrange(logits, 'b n c -> b c n')
        # Restrict each loss to its own vocabulary slice and sequence region.
        text_logits = logits[:, :self.vocab_size, :self.text_seq_length].contiguous().float()
        image_logits = logits[:, self.vocab_size:, self.text_seq_length:].contiguous().float()
        loss_text = F.cross_entropy(
            text_logits,
            labels[:, :self.text_seq_length])
        loss_img = F.cross_entropy(
            image_logits,
            labels[:, self.text_seq_length:])
        # Weighted average; loss_img_weight=7 makes the image term dominate.
        loss = (loss_text + self.loss_img_weight * loss_img) / (self.loss_img_weight + 1)
        return loss, {'text': loss_text.data.detach().float(), 'image': loss_img.data.detach().float()}
| @ -0,0 +1,332 @@ | |||||
| # -*- coding: utf-8 -*- | |||||
| import math | |||||
| import torch | |||||
| from torch.nn import LayerNorm | |||||
| from .utils import divide, split_tensor_along_last_dim | |||||
@torch.jit.script
def gelu_impl(x):
    """OpenAI's gelu implementation (tanh approximation), TorchScript-compiled.

    0.7978845608028654 is sqrt(2/pi) from the standard tanh-GELU formula.
    """
    return 0.5 * x * (1.0 + torch.tanh(0.7978845608028654 * x * (1.0 + 0.044715 * x * x)))


def gelu(x):
    # Plain-Python wrapper so callers don't depend on the jit-scripted symbol.
    return gelu_impl(x)
class DalleTransformer(torch.nn.Module):
    """
    This module takes input from the embedding layer and its output can
    be used directly by a logit layer. It consists of L (num-layers)
    blocks of:
        layer norm
        self attention
        residual connection
        layer norm
        mlp
        residual connection
    followed by a final layer norm.

    Arguments:
        num_layers: Number of transformer layers.
        hidden_size: The hidden size of the self attention.
        num_attention_heads: number of attention heads in the self
                             attention.
        attention_dropout_prob: dropout probability of the attention
                                score in self attention.
        output_dropout_prob: dropout probability for the outputs
                             after self attention and final output.
        layernorm_epsilon: epsilon used in layernorm to avoid
                           division by zero.
    """
    # Per-layer sparse attention masks. DalleModel overwrites this with an
    # instance attribute after construction; the empty class-level default
    # means "no sparse masks". NOTE(review): mutable class attribute — safe
    # only because it is never mutated in place; confirm.
    _mask_map = []

    def __init__(self, num_layers, hidden_size, num_attention_heads, attention_dropout_prob, output_dropout_prob,
                 layernorm_epsilon=1.0e-5, cogview_sandwich_layernorm=False, cogview_pb_relax=False):
        super(DalleTransformer, self).__init__()
        # CogView stabilization of training features, see chapter 2.4 https://arxiv.org/pdf/2105.13290.pdf
        self.cogview_pb_relax = cogview_pb_relax
        # Transformer layers.
        self.layers = torch.nn.ModuleList([
            DalleTransformerLayer(
                hidden_size,
                num_attention_heads,
                attention_dropout_prob,
                output_dropout_prob,
                layernorm_epsilon,
                cogview_sandwich_layernorm=cogview_sandwich_layernorm,
                cogview_pb_relax=cogview_pb_relax,
            ) for _ in range(num_layers)
        ])
        # Final layer norm before output.
        self.final_layernorm = LayerNorm(hidden_size, eps=layernorm_epsilon)

    def forward(self, hidden_states, attention_mask, has_cache, use_cache):
        """Run all layers; returns (hidden states after final LN, cache flag)."""
        for i, layer in enumerate(self.layers):
            mask = attention_mask
            if len(self._mask_map):
                # Crop this layer's precomputed sparse mask to the current
                # sequence length and combine it with the causal mask.
                layer_mask = self._mask_map[i][:mask.size(2), :mask.size(3)]
                mask = torch.mul(attention_mask, layer_mask)
            hidden_states, present_has_cache = layer(hidden_states, mask, has_cache=has_cache, use_cache=use_cache)
        output = self.final_layernorm(hidden_states)
        return output, present_has_cache
class DalleTransformerLayer(torch.nn.Module):
    """
    A single layer transformer.

    We use the following notation:
        h: hidden size
        n: number of attention heads
        b: batch size
        s: sequence length
    Transformer layer takes input with size [b, s, h] and returns an
    output of the same size.

    Arguments:
        hidden_size: The hidden size of the self attention.
        num_attention_heads: number of attention heads in the self
                             attention.
        attention_dropout_prob: dropout probability of the attention
                                score in self attention.
        output_dropout_prob: dropout probability for the outputs
                             after self attention and final output.
        layernorm_epsilon: epsilon used in layernorm to avoid
                           division by zero.
    """

    def __init__(self,
                 hidden_size,
                 num_attention_heads,
                 attention_dropout_prob,
                 output_dropout_prob,
                 layernorm_epsilon,
                 cogview_sandwich_layernorm=False,
                 cogview_pb_relax=False):
        super(DalleTransformerLayer, self).__init__()
        # CogView stabilization of training features, see chapter 2.4 https://arxiv.org/pdf/2105.13290.pdf
        self.cogview_sandwich_layernorm = cogview_sandwich_layernorm
        self.cogview_pb_relax = cogview_pb_relax
        # Layernorm on the input data.
        self.input_layernorm = LayerNorm(hidden_size, eps=layernorm_epsilon)
        # "Sandwich" LN: extra LayerNorms just before each residual addition.
        if self.cogview_sandwich_layernorm:
            self.before_first_addition_layernorm = LayerNorm(hidden_size, eps=layernorm_epsilon)
            self.before_second_addition_layernorm = LayerNorm(hidden_size, eps=layernorm_epsilon)
        # Self attention.
        self.attention = DalleSelfAttention(
            hidden_size,
            num_attention_heads,
            attention_dropout_prob,
            output_dropout_prob,
            cogview_pb_relax=cogview_pb_relax
        )
        # Layernorm after the attention residual, before the MLP.
        self.post_attention_layernorm = LayerNorm(hidden_size, eps=layernorm_epsilon)
        # MLP
        self.mlp = DalleMLP(hidden_size, output_dropout_prob)

    def forward(self, hidden_states, ltor_mask, has_cache, use_cache):
        """Pre-LN transformer block; returns (output [b, s, h], cache flag)."""
        # hidden_states: [b, s, h]
        # ltor_mask: [1, 1, s, s]
        # Layer norm at the beginning of the transformer layer.
        layernorm_output = self.input_layernorm(hidden_states)
        # Self attention.
        attention_output, has_cache = self.attention(
            layernorm_output, ltor_mask, has_cache=has_cache, use_cache=use_cache)
        if self.cogview_sandwich_layernorm:
            attention_output = self.before_first_addition_layernorm(attention_output)
        # Residual connection.
        layernorm_input = hidden_states + attention_output
        # Layer norm post the self attention.
        layernorm_output = self.post_attention_layernorm(layernorm_input)
        # MLP.
        mlp_output = self.mlp(layernorm_output)
        if self.cogview_sandwich_layernorm:
            mlp_output = self.before_second_addition_layernorm(mlp_output)
        # Second residual connection.
        output = layernorm_input + mlp_output
        return output, has_cache
class DalleSelfAttention(torch.nn.Module):
    """
    Self-attention layer takes input with size [b, s, h] where b is
    the batch size, s is the sequence length, and h is the hidden size
    and creates output of the same size.

    Arguments:
        hidden_size: total hidden size of the layer (h).
        num_attention_heads: number of attention heads (n). Note that we
                             require n to be divisible by number of GPUs
                             used to parallelize the model. Also, we
                             require hidden size to be divisible by n.
        attention_dropout_prob: dropout probability for the attention scores.
        output_dropout_prob: dropout probability for the output.

    We use the following notation:
        h: hidden_size
        n: num_attention_heads
        p: number of partitions
        np: n/p
        hp: h/p
        hn: h/n
        b: batch size
        s: sequence length
    """

    def __init__(self, hidden_size, num_attention_heads,
                 attention_dropout_prob, output_dropout_prob, cogview_pb_relax=False):
        super(DalleSelfAttention, self).__init__()
        # CogView stabilization of training features, see chapter 2.4 https://arxiv.org/pdf/2105.13290.pdf
        self.cogview_pb_relax = cogview_pb_relax
        self.hidden_size = hidden_size
        self.num_attention_heads = num_attention_heads
        self.hidden_size_per_attention_head = divide(hidden_size, num_attention_heads)
        # Fused projection producing Q, K and V in a single matmul.
        self.query_key_value = torch.nn.Linear(hidden_size, 3*hidden_size)
        self.attention_dropout = torch.nn.Dropout(attention_dropout_prob)
        # Output.
        self.dense = torch.nn.Linear(hidden_size, hidden_size)
        self.output_dropout = torch.nn.Dropout(output_dropout_prob)

    def _transpose_for_scores(self, tensor):
        """ Transpose a 3D tensor [b, s, np*hn] into a 4D tensor with size [b, np, s, hn]. """
        new_tensor_shape = tensor.size()[:-1] + (self.num_attention_heads, self.hidden_size_per_attention_head)
        tensor = tensor.view(*new_tensor_shape)
        return tensor.permute(0, 2, 1, 3)

    def _calculate_attention_scores(self, query_layer, key_layer, ltor_mask):
        """Scaled dot-product scores with causal masking (optionally PB-relaxed)."""
        key_t = key_layer.transpose(-1, -2)
        if self.cogview_pb_relax:
            # PB-relax: scale Q by 1/sqrt(hn) *before* the matmul to keep the
            # products in a safer numeric range for fp16 (CogView, ch. 2.4).
            attention_scores = torch.matmul(
                query_layer / math.sqrt(self.hidden_size_per_attention_head),
                key_t
            )
        else:
            attention_scores = torch.matmul(query_layer, key_t) / math.sqrt(self.hidden_size_per_attention_head)
        # Zero masked positions, then push them to a large negative value so
        # softmax assigns them ~0 probability.
        attention_scores = torch.mul(attention_scores, ltor_mask) - 10000.0 * (1.0 - ltor_mask)
        if self.cogview_pb_relax:
            # normalize attention scores. Should not affect resulting softmax value
            alpha = 32
            attention_scores_scaled = attention_scores / alpha
            attention_scores_scaled_maxes, _ = attention_scores_scaled.detach().view(
                [attention_scores.size(0), attention_scores.size(1), -1]
            ).max(dim=-1)  # max per head per sample
            attention_scores_scaled_maxes = attention_scores_scaled_maxes.unsqueeze(-1).unsqueeze(-1).expand(
                [-1, -1, attention_scores.size(2), attention_scores.size(3)]
            )  # expand to [b, np, s, s]
            attention_scores = (attention_scores_scaled - attention_scores_scaled_maxes) * alpha
        return attention_scores

    def forward(self, hidden_states, ltor_mask, has_cache=False, use_cache=False,):
        # hidden_states: [b, s, h]
        # ltor_mask: [1, 1, s, s]
        # Attention heads. [b, s, hp]
        if has_cache and use_cache:
            # Incremental decoding: only the newest position needs fresh Q/K/V.
            mixed_x_layer = self.query_key_value(hidden_states[:, -1:, :])
        else:
            mixed_x_layer = self.query_key_value(hidden_states)
        (mixed_query_layer,
         mixed_key_layer,
         mixed_value_layer) = split_tensor_along_last_dim(mixed_x_layer, 3)
        query_layer = self._transpose_for_scores(mixed_query_layer)
        key_layer = self._transpose_for_scores(mixed_key_layer)
        value_layer = self._transpose_for_scores(mixed_value_layer)
        if use_cache and has_cache:
            # Prepend the cached Q/K/V from earlier steps.
            # NOTE(review): both branches compute the same full-sequence
            # scores; the cache only saves the QKV projection, not the
            # attention matmul — confirm this is the intended trade-off.
            value_layer = torch.cat((self.past_value, value_layer), dim=-2)
            query_layer = torch.cat((self.past_query, query_layer), dim=-2)
            key_layer = torch.cat((self.past_key, key_layer), dim=-2)
            attention_scores = self._calculate_attention_scores(
                query_layer=query_layer, key_layer=key_layer, ltor_mask=ltor_mask
            )
        else:
            attention_scores = self._calculate_attention_scores(
                query_layer=query_layer, key_layer=key_layer, ltor_mask=ltor_mask
            )
        if use_cache:
            # Cache lives on the module instance; not safe for concurrent or
            # interleaved generations — NOTE(review).
            self.past_query = query_layer
            self.past_key = key_layer
            self.past_value = value_layer
            has_cache = True
        else:
            has_cache = False
        # Attention probabilities. [b, np, s, s]
        attention_probs = torch.nn.Softmax(dim=-1)(attention_scores)
        # This is actually dropping out entire tokens to attend to, which might
        # seem a bit unusual, but is taken from the original Transformer paper.
        attention_probs = self.attention_dropout(attention_probs)
        # Context layer.
        # [b, np, s, hn]
        context_layer = torch.matmul(attention_probs, value_layer)
        # [b, s, np, hn]
        context_layer = context_layer.permute(0, 2, 1, 3).contiguous()
        new_context_layer_shape = context_layer.size()[:-2] + (self.hidden_size,)
        # [b, s, hp]
        context_layer = context_layer.view(*new_context_layer_shape)
        # Output. [b, s, h]
        output = self.dense(context_layer)
        output = self.output_dropout(output)
        return output, has_cache
class DalleMLP(torch.nn.Module):
    """Transformer feed-forward block: h -> 4h -> gelu -> h, then dropout.

    Arguments:
        hidden_size: The hidden size of the self attention (h).
        output_dropout_prob: dropout probability applied to the final output.
    """

    def __init__(self, hidden_size, output_dropout_prob):
        super(DalleMLP, self).__init__()
        # Expansion to 4h and projection back to h, as in the standard FFN.
        self.dense_h_to_4h = torch.nn.Linear(hidden_size, 4*hidden_size)
        self.dense_4h_to_h = torch.nn.Linear(4*hidden_size, hidden_size)
        self.dropout = torch.nn.Dropout(output_dropout_prob)

    def forward(self, hidden_states):
        # [b, s, h] -> [b, s, 4h] -> gelu -> [b, s, h] -> dropout
        expanded = gelu(self.dense_h_to_4h(hidden_states))
        projected = self.dense_4h_to_h(expanded)
        return self.dropout(projected)
| @ -0,0 +1,54 @@ | |||||
| # -*- coding: utf-8 -*- | |||||
| import torch | |||||
def exists(val):
    """Return True unless *val* is None."""
    return False if val is None else True
def is_empty(t):
    """Return True when tensor *t* contains zero elements."""
    return t.numel() == 0
def ensure_divisibility(numerator, denominator):
    """Assert that *numerator* is an exact multiple of *denominator*."""
    assert not numerator % denominator, '{} is not divisible by {}'.format(
        numerator, denominator)


def divide(numerator, denominator):
    """Integer-divide *numerator* by *denominator*, asserting divisibility first."""
    ensure_divisibility(numerator, denominator)
    return numerator // denominator
def split_tensor_along_last_dim(tensor, num_partitions, contiguous_split_chunks=False):
    """
    Split *tensor* into equal chunks along its last dimension.

    Arguments:
        tensor: input tensor; its last dimension must be divisible
            by num_partitions.
        num_partitions: number of equal-sized partitions to produce.
        contiguous_split_chunks: if True, force each chunk to be
            contiguous in memory (torch.split returns views).
    """
    last_dim = tensor.dim() - 1
    last_dim_size = tensor.size(last_dim)
    # Same divisibility guard `divide` performs, inlined to keep the split local.
    assert last_dim_size % num_partitions == 0, '{} is not divisible by {}'.format(
        last_dim_size, num_partitions)
    chunks = torch.split(tensor, last_dim_size // num_partitions, dim=last_dim)
    if contiguous_split_chunks:
        chunks = tuple(chunk.contiguous() for chunk in chunks)
    return chunks
def init_method_normal(std=0.02):
    """Build an initializer drawing in-place from N(0, std).

    Only used for embeddings; the transformer layers carry their
    own initializer.
    """
    def _init(tensor):
        return torch.nn.init.normal_(tensor, mean=0.0, std=std)
    return _init
| @ -0,0 +1,61 @@ | |||||
| # -*- coding: utf-8 -*- | |||||
| import torch | |||||
| import torchvision | |||||
| import transformers | |||||
| import more_itertools | |||||
| import numpy as np | |||||
| import matplotlib.pyplot as plt | |||||
| from tqdm import tqdm | |||||
| from . import utils | |||||
def generate_images(text, tokenizer, dalle, vae, top_k, top_p, images_num, temperature=1.0, bs=8, seed=None,
                    use_cache=True):
    """Autoregressively sample ``images_num`` images for a text prompt.

    The text is tokenized, image tokens are then sampled one position at a
    time with temperature scaling and top-k/top-p filtering, and the sampled
    codebook indices are decoded to pixels by the VAE.

    Args:
        text: prompt; lower-cased and stripped before tokenization.
        tokenizer: object providing ``encode_text``.
        dalle: autoregressive model; queried via ``get_param`` for vocab
            size, sequence lengths and device.
        vae: model exposing ``decode`` from codebook indices to images.
        top_k, top_p: filtering parameters forwarded to
            ``transformers.top_k_top_p_filtering``.
        images_num: total number of images to generate.
        temperature: divisor applied to logits before filtering
            (higher = more random).
        bs: micro-batch size; generation runs in chunks of at most ``bs``.
        seed: optional seed for reproducible sampling.
        use_cache: forwarded to the model as its key/value caching flag.

    Returns:
        Tuple ``(pil_images, scores)``: list of PIL images and a list of
        per-image scores (sum of the sampled tokens' probabilities).
    """
    if seed is not None:
        utils.seed_everything(seed)
    vocab_size = dalle.get_param('vocab_size')
    text_seq_length = dalle.get_param('text_seq_length')
    image_seq_length = dalle.get_param('image_seq_length')
    total_seq_length = dalle.get_param('total_seq_length')
    device = dalle.get_param('device')
    text = text.lower().strip()
    input_ids = tokenizer.encode_text(text, text_seq_length=text_seq_length)
    pil_images, scores = [], []
    for chunk in more_itertools.chunked(range(images_num), bs):
        chunk_bs = len(chunk)
        with torch.no_grad():
            # Causal (lower-triangular) mask over the whole text+image sequence.
            attention_mask = torch.tril(torch.ones((chunk_bs, 1, total_seq_length, total_seq_length), device=device))
            out = input_ids.unsqueeze(0).repeat(chunk_bs, 1).to(device)
            has_cache = False
            sample_scores = []
            for _ in tqdm(range(out.shape[1], total_seq_length)):
                logits, has_cache = dalle(out, attention_mask,
                                          has_cache=has_cache, use_cache=use_cache, return_loss=False)
                # Last position only; slice past vocab_size — presumably the
                # image-token part of the logits (they feed vae.decode below).
                logits = logits[:, -1, vocab_size:]
                logits /= temperature
                filtered_logits = transformers.top_k_top_p_filtering(logits, top_k=top_k, top_p=top_p)
                probs = torch.nn.functional.softmax(filtered_logits, dim=-1)
                sample = torch.multinomial(probs, 1)
                # Probability of each sampled token, accumulated for scoring.
                sample_scores.append(probs[torch.arange(probs.size(0)), sample.transpose(0, 1)])
                out = torch.cat((out, sample), dim=-1)
            # Trailing image_seq_length positions hold the sampled codebook ids.
            codebooks = out[:, -image_seq_length:]
            images = vae.decode(codebooks)
            pil_images += utils.torch_tensors_to_pil_list(images)
            scores += torch.cat(sample_scores).sum(0).detach().cpu().numpy().tolist()
    return pil_images, scores
def show(pil_images, nrow=4):
    """Display *pil_images* as a single matplotlib grid without axis ticks."""
    grid = torchvision.utils.make_grid(utils.pil_list_to_torch_tensors(pil_images), nrow=nrow)
    imgs = grid if isinstance(grid, list) else [grid.cpu()]
    _, axs = plt.subplots(ncols=len(imgs), squeeze=False, figsize=(14, 14))
    for col, img in enumerate(imgs):
        pil_img = torchvision.transforms.functional.to_pil_image(img.detach())
        axs[0, col].imshow(np.asarray(pil_img))
        axs[0, col].set(xticklabels=[], yticklabels=[], xticks=[], yticks=[])
| @ -0,0 +1,64 @@ | |||||
| # -*- coding: utf-8 -*- | |||||
| from os.path import join | |||||
| import torch | |||||
| import numpy as np | |||||
| import youtokentome as yttm | |||||
| from huggingface_hub import hf_hub_url, cached_download | |||||
def get_tokenizer(path=None, cache_dir='/tmp/rudalle'):
    """Load the YTTM BPE tokenizer, downloading the default model if needed.

    Args:
        path: optional local path to a BPE model file; when None the model
            is fetched from the HuggingFace hub into *cache_dir*.
        cache_dir: root directory for the downloaded model.

    Returns:
        A ready ``YTTMTokenizerWrapper``.
    """
    if path is None:
        repo_id = 'shonenkov/rudalle-utils'
        filename = 'bpe.model'
        cache_dir = join(cache_dir, 'tokenizer')
        cached_download(hf_hub_url(repo_id=repo_id, filename=filename),
                        cache_dir=cache_dir, force_filename=filename)
        path = join(cache_dir, filename)
    tokenizer = YTTMTokenizerWrapper(yttm.BPE(model=path))
    print('tokenizer --> ready')
    return tokenizer
class YTTMTokenizerWrapper:
    """Thin adapter around a YouTokenToMe BPE model.

    Adds BOS/EOS framing, fixed-length padding/truncation on encode and
    special-id filtering on decode. The id constants mirror the trained
    BPE vocabulary layout.
    """

    eos_id = 3
    bos_id = 2
    unk_id = 1
    pad_id = 0

    def __init__(self, tokenizer):
        self.tokenizer = tokenizer

    def __len__(self):
        return self.vocab_size()

    def get_pad_token_id(self):
        """Return the id of the '<PAD>' subword in the BPE model."""
        return self.tokenizer.subword_to_id('<PAD>')

    def vocab_size(self):
        """Return the number of subwords in the underlying vocabulary."""
        return self.tokenizer.vocab_size()

    def encode_text(self, text, text_seq_length, bpe_dropout=0.0):
        """Encode *text* into a fixed-length LongTensor framed by BOS/EOS."""
        ids = self.tokenizer.encode([text], output_type=yttm.OutputType.ID, dropout_prob=bpe_dropout)[0]
        framed = [self.bos_id] + ids + [self.eos_id]
        return self.prepare_tokens(framed, text_seq_length)

    def decode_text(self, encoded):
        """Decode a tensor of ids back to text, dropping special tokens."""
        ignore = [self.eos_id, self.bos_id, self.unk_id, self.pad_id]
        return self.tokenizer.decode(encoded.cpu().numpy().tolist(), ignore_ids=ignore)[0]

    @staticmethod
    def prepare_tokens(tokens, text_seq_length):
        """Pad with zeros (or truncate) to exactly *text_seq_length* ids."""
        pad_count = text_seq_length - len(tokens)
        if pad_count > 0:
            # padding goes after the text tokens
            tokens = np.hstack((tokens, np.zeros(pad_count)))
        if len(tokens) > text_seq_length:
            tokens = tokens[:text_seq_length]
        return torch.tensor(tokens).long()
| @ -0,0 +1,36 @@ | |||||
| # -*- coding: utf-8 -*- | |||||
| import os | |||||
| import random | |||||
| import torch | |||||
| import torchvision | |||||
| import numpy as np | |||||
def seed_everything(seed):
    """Seed all random number generators for reproducible runs.

    Covers Python's ``random``, hash randomization, NumPy, and torch
    (CPU and CUDA).

    Args:
        seed: integer seed applied to every generator.
    """
    random.seed(seed)
    os.environ['PYTHONHASHSEED'] = str(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    torch.cuda.manual_seed(seed)
    # Force deterministic cuDNN kernels. `benchmark` must be False here:
    # when True, cuDNN autotunes the fastest algorithm per input shape,
    # which can differ between runs and silently break the reproducibility
    # this function exists to provide (the original set it to True).
    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = False
def torch_tensors_to_pil_list(input_images):
    """Convert an iterable/batch of image tensors into RGB PIL images."""
    return [
        torchvision.transforms.functional.to_pil_image(img.cpu().detach()).convert('RGB')
        for img in input_images
    ]
def pil_list_to_torch_tensors(pil_images):
    """Stack PIL images into one uint8 tensor batch of shape (N, C, H, W)."""
    tensors = [
        torch.from_numpy(np.array(img, dtype=np.uint8)).permute(2, 0, 1).unsqueeze(0)
        for img in pil_images
    ]
    return torch.cat(tensors, dim=0)
| @ -0,0 +1,24 @@ | |||||
| # -*- coding: utf-8 -*- | |||||
| from os.path import dirname, abspath, join | |||||
| import torch | |||||
| from huggingface_hub import hf_hub_url, cached_download | |||||
| from omegaconf import OmegaConf | |||||
| from .model import VQGanGumbelVAE | |||||
def get_vae(pretrained=True, cache_dir='/tmp/rudalle'):
    """Build the Gumbel VQ-GAN VAE, optionally loading pretrained weights.

    Args:
        pretrained: when True, download the sber checkpoint and load it
            (non-strict, so mismatched keys are tolerated).
        cache_dir: root directory for downloaded checkpoints.
    """
    config_path = join(dirname(abspath(__file__)), 'vqgan.gumbelf8-sber.config.yml')
    vae = VQGanGumbelVAE(OmegaConf.load(config_path))
    if pretrained:
        repo_id = 'shonenkov/rudalle-utils'
        filename = 'vqgan.gumbelf8-sber.model.ckpt'
        cache_dir = join(cache_dir, 'vae')
        cached_download(hf_hub_url(repo_id=repo_id, filename=filename),
                        cache_dir=cache_dir, force_filename=filename)
        checkpoint = torch.load(join(cache_dir, filename), map_location='cpu')
        vae.model.load_state_dict(checkpoint['state_dict'], strict=False)
    print('vae --> ready')
    return vae
| @ -0,0 +1,100 @@ | |||||
| # -*- coding: utf-8 -*- | |||||
| from math import sqrt, log | |||||
| import torch | |||||
| import torch.nn as nn | |||||
| import torch.nn.functional as F | |||||
| from torch import einsum | |||||
| from einops import rearrange | |||||
| from taming.modules.diffusionmodules.model import Encoder, Decoder | |||||
class VQGanGumbelVAE(torch.nn.Module):
    """Wrapper exposing encode-to-codebook / decode-from-codebook around GumbelVQ.

    NOTE(review): image_size is hard-coded to 256 — presumably matching the
    ddconfig resolution of the shipped config; confirm for other configs.
    """

    def __init__(self, config):
        super().__init__()
        self.model = GumbelVQ(
            ddconfig=config.model.params.ddconfig,
            n_embed=config.model.params.n_embed,
            embed_dim=config.model.params.embed_dim,
            kl_weight=config.model.params.kl_weight,
        )
        self.num_layers = int(log(config.model.params.ddconfig.attn_resolutions[0]) / log(2))
        self.image_size = 256
        self.num_tokens = config.model.params.n_embed

    @torch.no_grad()
    def get_codebook_indices(self, img):
        """Map [0, 1] images to flattened codebook indices of shape (b, h*w)."""
        scaled = (2 * img) - 1  # VQGAN works in the [-1, 1] range
        _, _, (_, _, indices) = self.model.encode(scaled)
        return rearrange(indices, 'b h w -> b (h w)')

    def decode(self, img_seq):
        """Decode a (b, n) sequence of codebook indices back to [0, 1] images."""
        _, seq_len = img_seq.shape
        one_hot = torch.nn.functional.one_hot(img_seq, num_classes=self.num_tokens).float()
        z = one_hot @ self.model.quantize.embed.weight
        z = rearrange(z, 'b (h w) c -> b c h w', h=int(sqrt(seq_len)))
        decoded = self.model.decode(z)
        return (decoded.clamp(-1., 1.) + 1) * 0.5
class GumbelQuantize(nn.Module):
    """
    credit to @karpathy: https://github.com/karpathy/deep-vector-quantization/blob/main/model.py (thanks!)
    Gumbel Softmax trick quantizer
    Categorical Reparameterization with Gumbel-Softmax, Jang et al. 2016
    https://arxiv.org/abs/1611.01144
    """

    def __init__(self, num_hiddens, embedding_dim, n_embed, straight_through=True,
                 kl_weight=5e-4, temp_init=1.0, use_vqinterface=True):
        super().__init__()
        self.embedding_dim = embedding_dim
        self.n_embed = n_embed
        self.straight_through = straight_through
        self.temperature = temp_init
        self.kl_weight = kl_weight
        # 1x1 conv producing per-position logits over the codebook.
        self.proj = nn.Conv2d(num_hiddens, n_embed, 1)
        self.embed = nn.Embedding(self.n_embed, self.embedding_dim)
        self.use_vqinterface = use_vqinterface

    def forward(self, z, temp=None, return_logits=False):
        """Quantize (b, num_hiddens, h, w) features.

        Returns ``(z_q, kl_loss, (None, None, indices)[, logits])`` when
        ``use_vqinterface``, otherwise ``(z_q, kl_loss, indices)``.
        """
        # Evaluation always takes hard one-hot samples; training follows
        # the straight-through setting.
        hard = True if not self.training else self.straight_through
        tau = self.temperature if temp is None else temp
        logits = self.proj(z)
        soft_one_hot = F.gumbel_softmax(logits, tau=tau, dim=1, hard=hard)
        z_q = einsum('b n h w, n d -> b d h w', soft_one_hot, self.embed.weight)
        # KL of the code distribution against the uniform prior over codes.
        qy = F.softmax(logits, dim=1)
        diff = self.kl_weight * torch.sum(qy * torch.log(qy * self.n_embed + 1e-10), dim=1).mean()
        ind = soft_one_hot.argmax(dim=1)
        if not self.use_vqinterface:
            return z_q, diff, ind
        if return_logits:
            return z_q, diff, (None, None, ind), logits
        return z_q, diff, (None, None, ind)
class GumbelVQ(nn.Module):
    """Encoder / decoder pair with Gumbel-softmax vector quantization between."""

    def __init__(self, ddconfig, n_embed, embed_dim, kl_weight=1e-8):
        super().__init__()
        z_channels = ddconfig['z_channels']
        self.encoder = Encoder(**ddconfig)
        self.decoder = Decoder(**ddconfig)
        # NOTE(review): the quantizer is constructed for z_channels inputs,
        # yet encode() feeds it embed_dim channels from quant_conv; this only
        # lines up when embed_dim == z_channels (true for the shipped config,
        # both 256) — confirm before using other configs.
        self.quantize = GumbelQuantize(z_channels, embed_dim, n_embed=n_embed, kl_weight=kl_weight, temp_init=1.0)
        self.quant_conv = torch.nn.Conv2d(z_channels, embed_dim, 1)
        self.post_quant_conv = torch.nn.Conv2d(embed_dim, z_channels, 1)

    def encode(self, x):
        """Encode images to (quantized latents, kl loss, (None, None, indices))."""
        quant, emb_loss, info = self.quantize(self.quant_conv(self.encoder(x)))
        return quant, emb_loss, info

    def decode(self, quant):
        """Decode quantized latents back to image space."""
        return self.decoder(self.post_quant_conv(quant))
| @ -0,0 +1,34 @@ | |||||
| model: | |||||
| base_learning_rate: 4.5e-06 | |||||
| target: taming.models.vqgan.GumbelVQ | |||||
| params: | |||||
| kl_weight: 1.0e-08 | |||||
| embed_dim: 256 | |||||
| n_embed: 8192 | |||||
| monitor: val/rec_loss | |||||
| temperature_scheduler_config: | |||||
| target: taming.lr_scheduler.LambdaWarmUpCosineScheduler | |||||
| params: | |||||
| warm_up_steps: 0 | |||||
| max_decay_steps: 1000001 | |||||
| lr_start: 0.9 | |||||
| lr_max: 0.9 | |||||
| lr_min: 1.0e-06 | |||||
| ddconfig: | |||||
| double_z: false | |||||
| z_channels: 256 | |||||
| resolution: 256 | |||||
| in_channels: 3 | |||||
| out_ch: 3 | |||||
| ch: 128 | |||||
| ch_mult: | |||||
| - 1 | |||||
| - 1 | |||||
| - 2 | |||||
| - 4 | |||||
| num_res_blocks: 2 | |||||
| attn_resolutions: | |||||
| - 32 | |||||
| dropout: 0.0 | |||||
| lossconfig: | |||||
| target: taming.modules.losses.vqperceptual.DummyLoss | |||||
| @ -0,0 +1,13 @@ | |||||
| [pep8] | |||||
| max-line-length = 120 | |||||
| exclude = .tox,*migrations*,.json | |||||
| [flake8] | |||||
| max-line-length = 120 | |||||
| exclude = .tox,*migrations*,.json | |||||
| [autopep8-wrapper] | |||||
| exclude = .tox,*migrations*,.json | |||||
| [check-docstring-first] | |||||
| exclude = .tox,*migrations*,.json | |||||
| @ -0,0 +1,39 @@ | |||||
| # -*- coding: utf-8 -*- | |||||
| import io | |||||
| from os.path import abspath, dirname | |||||
| import PIL | |||||
| import pytest | |||||
| import requests | |||||
| from rudalle import get_tokenizer, get_rudalle_model, get_vae | |||||
| TEST_ROOT = dirname(abspath(__file__)) | |||||
@pytest.fixture(scope='module')
def vae():
    """Module-scoped VAE built without pretrained weights (fast, no download)."""
    yield get_vae(pretrained=False)
@pytest.fixture(scope='module')
def yttm_tokenizer():
    """Module-scoped YTTM tokenizer (downloads the BPE model once)."""
    yield get_tokenizer()
@pytest.fixture(scope='module')
def sample_image():
    """Download a small reference image once per test module."""
    url = 'https://cdn.kqed.org/wp-content/uploads/sites/12/2013/12/rudolph.png'
    # Bound the network wait: without a timeout, requests.get can hang
    # forever and stall the whole test run on a slow/unreachable CDN.
    resp = requests.get(url, timeout=60)
    resp.raise_for_status()
    image = PIL.Image.open(io.BytesIO(resp.content))
    yield image
@pytest.fixture(scope='module')
def small_dalle():
    """Randomly initialised 'small' ruDALL-E on CPU in fp32."""
    return get_rudalle_model('small', pretrained=False, fp16=False, device='cpu')
| @ -0,0 +1,31 @@ | |||||
| # -*- coding: utf-8 -*- | |||||
| import torch | |||||
| import pytest | |||||
| from .test_vae import preprocess | |||||
@pytest.mark.parametrize('text', [
    'мальчик играет с оленем',
])
def test_forward_step_and_criterion(text, sample_image, yttm_tokenizer, vae, small_dalle):
    """Full text+image forward pass with loss computation on the small model."""
    bs = 4
    text_seq_length = small_dalle.get_param('text_seq_length')
    total_seq_length = small_dalle.get_param('total_seq_length')
    device = small_dalle.get_param('device')
    img = sample_image.copy()
    img = preprocess(img, target_image_size=256)
    images = img.repeat(bs, 1, 1, 1).to(device)
    text = text.lower().strip()
    text_input_ids = yttm_tokenizer.encode_text(text, text_seq_length=text_seq_length)
    text_input_ids = text_input_ids.unsqueeze(0).repeat(bs, 1).to(device)
    # Causal mask over the full (text + image) sequence.
    attention_mask = torch.tril(torch.ones((bs, 1, total_seq_length, total_seq_length), device=device))
    with torch.no_grad():
        image_input_ids = vae.get_codebook_indices(images)
        input_ids = torch.cat((text_input_ids, image_input_ids), dim=1)
        loss, loss_values = small_dalle.forward(input_ids, attention_mask, return_loss=True)
    # isinstance is the idiomatic check: `type(x) == T` breaks on subclasses
    # (and `.data` is legacy — `.item()` on the detached loss is enough).
    assert isinstance(loss.detach().item(), float)
    assert isinstance(loss_values, dict)
| @ -0,0 +1,17 @@ | |||||
| # -*- coding: utf-8 -*- | |||||
| import pytest | |||||
@pytest.mark.parametrize('text, text_seq_length, bpe_dropout', [
    ('hello, how are you?', 128, 0.1),
    ('hello, how are you?', 128, 0.5),
    ('hello, how are you?', 128, 1.0),
    ('hello ... how are you ?', 256, 1.0),
    ('a person standing at a table with bottles of win', 64, 0.5),
    ('привет как дела???', 76, 0.0),
    ('клип на русском языке :)', 76, 0.1),
])
def test_encode_decode_text_yttm(yttm_tokenizer, text, text_seq_length, bpe_dropout):
    """Encoding then decoding must round-trip the original text exactly."""
    encoded = yttm_tokenizer.encode_text(text, text_seq_length=text_seq_length, bpe_dropout=bpe_dropout)
    assert yttm_tokenizer.decode_text(encoded) == text
| @ -0,0 +1,49 @@ | |||||
| # -*- coding: utf-8 -*- | |||||
| import PIL | |||||
| import pytest | |||||
| import torch | |||||
| import torchvision.transforms as T | |||||
| import torchvision.transforms.functional as TF | |||||
@pytest.mark.parametrize('target_image_size', [128, 192, 256])
def test_decode_vae(vae, sample_image, target_image_size):
    """Image -> codebook indices -> decoded image keeps the expected shape."""
    img = preprocess(sample_image.copy(), target_image_size=target_image_size)
    with torch.no_grad():
        decoded = vae.decode(vae.get_codebook_indices(img))
    assert decoded.shape == (1, 3, target_image_size, target_image_size)
@pytest.mark.parametrize('target_image_size', [128, 192, 256])
def test_reconstruct_vae(vae, sample_image, target_image_size):
    """Encode/decode through the raw VQGAN preserves the batch shape."""
    with torch.no_grad():
        x_vqgan = preprocess(sample_image.copy(), target_image_size=target_image_size)
        output = reconstruct_with_vqgan(preprocess_vqgan(x_vqgan), vae.model)
    assert output.shape == (1, 3, target_image_size, target_image_size)
def preprocess(img, target_image_size=256):
    """Resize the shorter side to *target_image_size*, center-crop, and
    return a (1, C, H, W) float tensor.

    Raises:
        ValueError: if the image's shorter side is below the target size.
    """
    shorter = min(img.size)
    if shorter < target_image_size:
        raise ValueError(f'min dim for image {shorter} < {target_image_size}')
    scale = target_image_size / shorter
    # PIL size is (w, h); TF.resize wants (h, w).
    new_size = (round(scale * img.size[1]), round(scale * img.size[0]))
    img = TF.resize(img, new_size, interpolation=PIL.Image.LANCZOS)
    img = TF.center_crop(img, output_size=[target_image_size, target_image_size])
    return T.ToTensor()(img).unsqueeze(0)
def preprocess_vqgan(x):
    """Map [0, 1] image tensors to the [-1, 1] range VQGAN expects."""
    return 2. * x - 1.
def reconstruct_with_vqgan(x, model):
    """Encode *x* with *model* and decode it straight back (reconstruction)."""
    z, _, [_, _, _] = model.encode(x)
    print(f'VQGAN --- {model.__class__.__name__}: latent shape: {z.shape[2:]}')
    return model.decode(z)