Browse Source

Add cache for image prompts and crop first for all types of image prompts

feature/image_prompts
oriBetelgeuse 4 years ago
parent
commit
d1ff5f07c4
15 changed files with 921 additions and 245 deletions
  1. +1
    -0
      .gitignore
  2. +18
    -0
      .gitlab-ci.yml
  3. +32
    -1
      README.md
  4. +28
    -216
      jupyters/ruDALLE-example-generation-A100.ipynb
  5. +714
    -0
      jupyters/ruDALLE-image-prompts-dress-mannequins-V100.ipynb
  6. BIN
      pics/russian-temple-image-prompt.png
  7. +1
    -1
      rudalle/__init__.py
  8. +3
    -0
      rudalle/dalle/__init__.py
  9. +4
    -0
      rudalle/dalle/fp16.py
  10. +6
    -0
      rudalle/dalle/model.py
  11. +56
    -10
      rudalle/dalle/transformer.py
  12. +19
    -3
      rudalle/image_prompts.py
  13. +27
    -7
      rudalle/pipelines.py
  14. +3
    -7
      tests/test_image_prompts.py
  15. +9
    -0
      tests/test_show.py

+ 1
- 0
.gitignore View File

@ -166,3 +166,4 @@ runs/
jupyters/custom_*
*logs/
.DS_store

+ 18
- 0
.gitlab-ci.yml View File

@ -0,0 +1,18 @@
stages:
- test
all_branch_test:
stage: test
tags:
- docker
image: python:3.9
script:
- apt-get update ##[edited]
- apt-get install ffmpeg libsm6 libxext6 -y
- pip install cython
- pip install -r requirements-test.txt --no-cache-dir
- pip install codecov
- pytest --cov=rudalle tests/
- bash <(curl -s https://codecov.io/bash) -t $CODECOV_TOKEN
except:
- tags

+ 32
- 1
README.md View File

@ -1,15 +1,28 @@
# ruDALL-E
### Generate images from texts
[![Apache license](https://img.shields.io/badge/License-Apache-blue.svg)](https://www.apache.org/licenses/LICENSE-2.0)
[![Coverage Status](https://codecov.io/gh/sberbank-ai/ru-dalle/branch/master/graphs/badge.svg)](https://codecov.io/gh/sberbank-ai/ru-dalle)
[![pipeline](https://gitlab.com/shonenkov/ru-dalle/badges/master/pipeline.svg)](https://gitlab.com/shonenkov/ru-dalle/-/pipelines)
[![pre-commit.ci status](https://results.pre-commit.ci/badge/github/sberbank-ai/ru-dalle/master.svg)](https://results.pre-commit.ci/latest/github/sberbank-ai/ru-dalle/master)
```
pip install rudalle==0.0.1rc1
pip install rudalle==0.0.1rc5
```
### 🤗 HF Models:
[ruDALL-E Malevich (XL)](https://huggingface.co/sberbank-ai/rudalle-Malevich)
### Minimal Example:
[![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/drive/1wGE-046et27oHvNlBNPH07qrEQNE04PQ?usp=sharing)
[![Kaggle](https://kaggle.com/static/images/open-in-kaggle.svg)](https://www.kaggle.com/shonenkov/rudalle-example-generation)
[![Hugging Face Spaces](https://img.shields.io/badge/%F0%9F%A4%97%20Hugging%20Face-Spaces-blue)](https://huggingface.co/spaces/anton-l/rudall-e)
**Finetuning example**
[![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/drive/1Tb7J4PvvegWOybPfUubl5O7m5I24CBg5?usp=sharing)
**English translation example**
[![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/drive/12fbO6YqtzHAHemY2roWQnXvKkdidNQKO?usp=sharing)
### generation by ruDALLE:
```python
@ -65,3 +78,21 @@ show(sr_images, 3)
text, seed = 'красивая тян из аниме', 6955
```
![](./pics/anime-girl-super-resolution.png)
### Image Prompt
see `jupyters/ruDALLE-image-prompts-A100.ipynb`
```python
text, seed = 'Храм Василия Блаженного', 42
skyes = [red_sky, sunny_sky, cloudy_sky, night_sky]
```
![](./pics/russian-temple-image-prompt.png)
### 🚀 Contributors 🚀
- [@neverix](https://www.kaggle.com/neverix) thanks a lot for contributing for speed up of inference
- [@Igor Pavlov](https://github.com/boomb0om) trained model and prepared code with [super-resolution](https://github.com/boomb0om/Real-ESRGAN-colab)
- [@oriBetelgeuse](https://github.com/oriBetelgeuse) thanks a lot for easy API of generation using image prompt
- [@Alex Wortega](https://github.com/AlexWortega) created first FREE version colab notebook with fine-tuning [ruDALL-E Malevich (XL)](https://huggingface.co/sberbank-ai/rudalle-Malevich) on sneakers domain 💪
- [@Anton Lozhkov](https://github.com/anton-l) Integrated to [Huggingface Spaces](https://huggingface.co/spaces) with [Gradio](https://github.com/gradio-app/gradio), see [here](https://huggingface.co/spaces/anton-l/rudall-e)

+ 28
- 216
jupyters/ruDALLE-example-generation-A100.ipynb
File diff suppressed because it is too large
View File


+ 714
- 0
jupyters/ruDALLE-image-prompts-dress-mannequins-V100.ipynb
File diff suppressed because it is too large
View File


BIN
pics/russian-temple-image-prompt.png View File

Before After
Width: 799  |  Height: 642  |  Size: 832 KiB

+ 1
- 1
rudalle/__init__.py View File

@ -22,4 +22,4 @@ __all__ = [
'image_prompts',
]
__version__ = '0.0.1-rc2'
__version__ = '0.0.1-rc6'

+ 3
- 0
rudalle/dalle/__init__.py View File

@ -59,6 +59,9 @@ def get_rudalle_model(name, pretrained=True, fp16=False, device='cpu', cache_dir
# TODO docstring
assert name in MODELS
if fp16 and device == 'cpu':
print('Warning! Using both fp16 and cpu doesnt support. You can use cuda device or turn off fp16.')
config = MODELS[name]
model = DalleModel(device=device, fp16=fp16, **config['model_params'])
if pretrained:


+ 4
- 0
rudalle/dalle/fp16.py View File

@ -58,3 +58,7 @@ class FP16Module(nn.Module):
def get_param(self, item):
return self.module.get_param(item)
def to(self, device, *args, **kwargs):
self.module.to(device)
return super().to(device, *args, **kwargs)

+ 6
- 0
rudalle/dalle/model.py View File

@ -169,3 +169,9 @@ class DalleModel(torch.nn.Module):
loss = (loss_text + self.loss_img_weight * loss_img) / (self.loss_img_weight + 1)
return loss, {'text': loss_text.data.detach().float(), 'image': loss_img.data.detach().float()}
def to(self, device, *args, **kwargs):
self.device = device
self._mask_map = [mask.to(device) for mask in self._mask_map]
self.transformer._mask_map = [mask.to(device) for mask in self.transformer._mask_map]
return super().to(device, *args, **kwargs)

+ 56
- 10
rudalle/dalle/transformer.py View File

@ -146,7 +146,7 @@ class DalleTransformerLayer(torch.nn.Module):
layernorm_output = self.input_layernorm(hidden_states)
# Self attention.
attention_output, has_cache = self.attention(
attention_output, att_has_cache = self.attention(
layernorm_output, ltor_mask, has_cache=has_cache, use_cache=use_cache)
if self.cogview_sandwich_layernorm:
@ -159,7 +159,8 @@ class DalleTransformerLayer(torch.nn.Module):
layernorm_output = self.post_attention_layernorm(layernorm_input)
# MLP.
mlp_output = self.mlp(layernorm_output)
mlp_output, mlp_has_cache = self.mlp(
layernorm_output, has_cache=has_cache, use_cache=use_cache)
if self.cogview_sandwich_layernorm:
mlp_output = self.before_second_addition_layernorm(mlp_output)
@ -167,7 +168,7 @@ class DalleTransformerLayer(torch.nn.Module):
# Second residual connection.
output = layernorm_input + mlp_output
return output, has_cache
return output, att_has_cache and mlp_has_cache
class DalleSelfAttention(torch.nn.Module):
@ -212,6 +213,12 @@ class DalleSelfAttention(torch.nn.Module):
self.dense = torch.nn.Linear(hidden_size, hidden_size)
self.output_dropout = torch.nn.Dropout(output_dropout_prob)
# Cache
self.cash_size = 0
self.past_key = None
self.past_value = None
self.past_output = None
def _transpose_for_scores(self, tensor):
""" Transpose a 3D tensor [b, s, np*hn] into a 4D tensor with size [b, np, s, hn]. """
new_tensor_shape = tensor.size()[:-1] + (self.num_attention_heads, self.hidden_size_per_attention_head)
@ -227,6 +234,7 @@ class DalleSelfAttention(torch.nn.Module):
)
else:
attention_scores = torch.matmul(query_layer, key_t) / math.sqrt(self.hidden_size_per_attention_head)
ltor_mask = ltor_mask[:, :, -attention_scores.shape[-2]:]
attention_scores = torch.mul(attention_scores, ltor_mask) - 10000.0 * (1.0 - ltor_mask)
if self.cogview_pb_relax:
# normalize attention scores. Should not affect resulting softmax value
@ -246,7 +254,8 @@ class DalleSelfAttention(torch.nn.Module):
# ltor_mask: [1, 1, s, s]
# Attention heads. [b, s, hp]
if has_cache and use_cache:
mixed_x_layer = self.query_key_value(hidden_states[:, -1:, :])
extra_cache_size = hidden_states.shape[-2] - self.chache_size
mixed_x_layer = self.query_key_value(hidden_states[:, -extra_cache_size:, :])
else:
mixed_x_layer = self.query_key_value(hidden_states)
@ -258,10 +267,10 @@ class DalleSelfAttention(torch.nn.Module):
key_layer = self._transpose_for_scores(mixed_key_layer)
value_layer = self._transpose_for_scores(mixed_value_layer)
# Can be simplified, but I didn't for readability's sake
if use_cache and has_cache:
value_layer = torch.cat((self.past_value, value_layer), dim=-2)
query_layer = torch.cat((self.past_query, query_layer), dim=-2)
key_layer = torch.cat((self.past_key, key_layer), dim=-2)
value_layer = torch.cat((self.past_value, value_layer), dim=-2)
attention_scores = self._calculate_attention_scores(
query_layer=query_layer, key_layer=key_layer, ltor_mask=ltor_mask
)
@ -271,13 +280,18 @@ class DalleSelfAttention(torch.nn.Module):
)
if use_cache:
self.past_query = query_layer
self.chache_size = hidden_states.shape[-2]
self.past_key = key_layer
self.past_value = value_layer
has_cache = True
else:
self.past_key = None
self.past_value = None
self.past_output = None
has_cache = False
if use_cache and has_cache:
attention_scores = attention_scores[..., -extra_cache_size:, :]
# Attention probabilities. [b, np, s, s]
attention_probs = torch.nn.Softmax(dim=-1)(attention_scores)
@ -298,6 +312,16 @@ class DalleSelfAttention(torch.nn.Module):
# Output. [b, s, h]
output = self.dense(context_layer)
if use_cache:
# Can be simplified, but I didn't for readability's sake
if has_cache:
output = torch.cat((self.past_output, output), dim=-2)
self.past_output = output
else:
self.past_output = output
has_cache = True
output = self.output_dropout(output)
return output, has_cache
@ -321,12 +345,34 @@ class DalleMLP(torch.nn.Module):
# Project back to h.
self.dense_4h_to_h = torch.nn.Linear(4*hidden_size, hidden_size)
self.dropout = torch.nn.Dropout(output_dropout_prob)
# MLP cache
self.cache_size = 0
self.past_x = None
def forward(self, hidden_states, has_cache=False, use_cache=False):
if has_cache and use_cache:
extra_cache_size = hidden_states.shape[-2] - self.cache_size
self.cache_size += extra_cache_size
hidden_states = hidden_states[:, -extra_cache_size:]
def forward(self, hidden_states):
# [b, s, 4hp]
x = self.dense_h_to_4h(hidden_states)
x = gelu(x)
# [b, s, h]
x = self.dense_4h_to_h(x)
if use_cache:
# Can be simplified, but I didn't for readability's sake
if has_cache:
x = torch.cat((self.past_x, x), dim=-2)
self.past_x = x
else:
self.cache_size = hidden_states.shape[-2]
self.past_x = x
has_cache = True
else:
self.past_x = None
has_cache = False
output = self.dropout(x)
return output
return output, has_cache

+ 19
- 3
rudalle/image_prompts.py View File

@ -28,9 +28,25 @@ class ImagePrompts:
@staticmethod
def _get_image_prompts(img, borders, vae, crop_first):
if crop_first:
assert borders['right'] + borders['left'] + borders['down'] == 0
up_border = borders['up'] * 8
_, _, [_, _, vqg_img] = vae.model.encode(img[:, :, :up_border, :])
bs, _, img_w, img_h = img.shape
vqg_img_w, vqg_img_h = img_w // 8, img_h // 8
vqg_img = torch.zeros((bs, vqg_img_w, vqg_img_h), dtype=torch.int32, device=img.device)
if borders['down'] != 0:
down_border = borders['down'] * 8
_, _, [_, _, down_vqg_img] = vae.model.encode(img[:, :, -down_border:, :])
vqg_img[:, -borders['down']:, :] = down_vqg_img
if borders['right'] != 0:
right_border = borders['right'] * 8
_, _, [_, _, right_vqg_img] = vae.model.encode(img[:, :, :, :right_border])
vqg_img[:, :, :borders['right']] = right_vqg_img
if borders['left'] != 0:
left_border = borders['left'] * 8
_, _, [_, _, left_vqg_img] = vae.model.encode(img[:, :, :, -left_border:])
vqg_img[:, :, -borders['left']:] = left_vqg_img
if borders['up'] != 0:
up_border = borders['up'] * 8
_, _, [_, _, up_vqg_img] = vae.model.encode(img[:, :, :up_border, :])
vqg_img[:, :borders['up'], :] = up_vqg_img
else:
_, _, [_, _, vqg_img] = vae.model.encode(img)


+ 27
- 7
rudalle/pipelines.py View File

@ -1,11 +1,15 @@
# -*- coding: utf-8 -*-
import os
from glob import glob
from os.path import join
import torch
import torchvision
import transformers
import more_itertools
import numpy as np
import matplotlib.pyplot as plt
from tqdm import tqdm
from tqdm.auto import tqdm
from . import utils
@ -35,9 +39,7 @@ def generate_images(text, tokenizer, dalle, vae, top_k, top_p, images_num, image
if image_prompts is not None:
prompts_idx, prompts = image_prompts.image_prompts_idx, image_prompts.image_prompts
prompts = prompts.repeat(images_num, 1)
if use_cache:
use_cache = False
print('Warning: use_cache changed to False')
prompts = prompts.repeat(chunk_bs, 1)
for idx in tqdm(range(out.shape[1], total_seq_length)):
idx -= text_seq_length
if image_prompts is not None and idx in prompts_idx:
@ -84,7 +86,18 @@ def cherry_pick_by_clip(pil_images, text, ruclip, ruclip_processor, device='cpu'
return top_pil_images, top_scores
def show(pil_images, nrow=4):
def show(pil_images, nrow=4, save_dir=None, show=True):
"""
:param pil_images: list of images in PIL
:param nrow: number of rows
:param save_dir: dir for separately saving of images, example: save_dir='./pics'
"""
if save_dir is not None:
os.makedirs(save_dir, exist_ok=True)
count = len(glob(join(save_dir, 'img_*.png')))
for i, pil_image in enumerate(pil_images):
pil_image.save(join(save_dir, f'img_{count+i}.png'))
imgs = torchvision.utils.make_grid(utils.pil_list_to_torch_tensors(pil_images), nrow=nrow)
if not isinstance(imgs, list):
imgs = [imgs.cpu()]
@ -92,5 +105,12 @@ def show(pil_images, nrow=4):
for i, img in enumerate(imgs):
img = img.detach()
img = torchvision.transforms.functional.to_pil_image(img)
axs[0, i].imshow(np.asarray(img))
axs[0, i].set(xticklabels=[], yticklabels=[], xticks=[], yticks=[])
if save_dir is not None:
count = len(glob(join(save_dir, 'group_*.png')))
img.save(join(save_dir, f'group_{count+i}.png'))
if show:
axs[0, i].imshow(np.asarray(img))
axs[0, i].set(xticklabels=[], yticklabels=[], xticks=[], yticks=[])
if show:
fix.show()
plt.show()

+ 3
- 7
tests/test_image_prompts.py View File

@ -13,10 +13,6 @@ def test_image_prompts(sample_image, vae, borders, crop_first):
img = sample_image.copy()
img = img.resize((256, 256))
image_prompt = ImagePrompts(img, borders, vae, crop_first=crop_first)
if crop_first:
assert image_prompt.image_prompts.shape[1] == borders['up'] * 32
assert len(image_prompt.image_prompts_idx) == borders['up'] * 32
else:
assert image_prompt.image_prompts.shape[1] == 32 * 32
assert len(image_prompt.image_prompts_idx) == (borders['up'] + borders['down']) * 32 \
+ (borders['left'] + borders['right']) * (32 - borders['up'] - borders['down'])
assert image_prompt.image_prompts.shape[1] == 32 * 32
assert len(image_prompt.image_prompts_idx) == (borders['up'] + borders['down']) * 32 \
+ (borders['left'] + borders['right']) * (32 - borders['up'] - borders['down'])

+ 9
- 0
tests/test_show.py View File

@ -0,0 +1,9 @@
# -*- coding: utf-8 -*-
from rudalle.pipelines import show
def test_show(sample_image):
img = sample_image.copy()
img = img.resize((256, 256))
pil_images = [img]*5
show(pil_images, nrow=2, save_dir='/tmp/pics', show=False)

Loading…
Cancel
Save