Browse Source

some fixes (#30)

* some fixes, fill readme

* edit readme, up version
pull/31/head v0.0.1rc5
Alex 4 years ago
committed by GitHub
parent
commit
958923572b
No known key found for this signature in database GPG Key ID: 4AEE18F83AFDEB23
5 changed files with 48 additions and 16 deletions
  1. +5
    -5
      README.md
  2. +1
    -1
      rudalle/__init__.py
  3. +5
    -2
      rudalle/image_prompts.py
  4. +28
    -8
      rudalle/pipelines.py
  5. +9
    -0
      tests/test_show.py

+ 5
- 5
README.md View File

@ -6,7 +6,7 @@
[![pre-commit.ci status](https://results.pre-commit.ci/badge/github/sberbank-ai/ru-dalle/master.svg)](https://results.pre-commit.ci/latest/github/sberbank-ai/ru-dalle/master)
```
pip install rudalle==0.0.1rc4
pip install rudalle==0.0.1rc5
```
### 🤗 HF Models:
[ruDALL-E Malevich (XL)](https://huggingface.co/sberbank-ai/rudalle-Malevich)
@ -18,13 +18,12 @@ pip install rudalle==0.0.1rc4
[![Kaggle](https://kaggle.com/static/images/open-in-kaggle.svg)](https://www.kaggle.com/shonenkov/rudalle-example-generation)
[![Hugging Face Spaces](https://img.shields.io/badge/%F0%9F%A4%97%20Hugging%20Face-Spaces-blue)](https://huggingface.co/spaces/anton-l/rudall-e)
**English translation example**
[![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/drive/12fbO6YqtzHAHemY2roWQnXvKkdidNQKO?usp=sharing)
**Finetuning example**
[![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/drive/1Tb7J4PvvegWOybPfUubl5O7m5I24CBg5?usp=sharing)
**English translation example**
[![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/drive/12fbO6YqtzHAHemY2roWQnXvKkdidNQKO?usp=sharing)
### generation by ruDALLE:
```python
from rudalle.pipelines import generate_images, show, super_resolution, cherry_pick_by_clip
@ -95,4 +94,5 @@ skyes = [red_sky, sunny_sky, cloudy_sky, night_sky]
- [@neverix](https://www.kaggle.com/neverix) thanks a lot for contributing for speed up of inference
- [@Igor Pavlov](https://github.com/boomb0om) trained model and prepared code with [super-resolution](https://github.com/boomb0om/Real-ESRGAN-colab)
- [@oriBetelgeuse](https://github.com/oriBetelgeuse) thanks a lot for easy API of generation using image prompt
- [@Alex Wortega](https://github.com/AlexWortega) created first FREE version colab notebook with fine-tuning [ruDALL-E Malevich (XL)](https://huggingface.co/sberbank-ai/rudalle-Malevich) on sneakers domain 💪
- [@Anton Lozhkov](https://github.com/anton-l) Integrated to [Huggingface Spaces](https://huggingface.co/spaces) with [Gradio](https://github.com/gradio-app/gradio), see [here](https://huggingface.co/spaces/anton-l/rudall-e)

+ 1
- 1
rudalle/__init__.py View File

@ -22,4 +22,4 @@ __all__ = [
'image_prompts',
]
__version__ = '0.0.1-rc4'
__version__ = '0.0.1-rc5'

+ 5
- 2
rudalle/image_prompts.py View File

@ -18,6 +18,7 @@ class ImagePrompts:
self.device = device
img = self._preprocess_img(pil_image)
self.image_prompts_idx, self.image_prompts = self._get_image_prompts(img, borders, vae, crop_first)
self.allow_cache = True
def _preprocess_img(self, pil_img):
img = torch.tensor(np.array(pil_img.convert('RGB')).transpose(2, 0, 1)) / 255.
@ -25,8 +26,7 @@ class ImagePrompts:
img = (2 * img) - 1
return img
@staticmethod
def _get_image_prompts(img, borders, vae, crop_first):
def _get_image_prompts(self, img, borders, vae, crop_first):
if crop_first:
assert borders['right'] + borders['left'] + borders['down'] == 0
up_border = borders['up'] * 8
@ -34,6 +34,9 @@ class ImagePrompts:
else:
_, _, [_, _, vqg_img] = vae.model.encode(img)
if borders['right'] + borders['left'] + borders['down'] != 0:
self.allow_cache = False # TODO fix cache in attention
bs, vqg_img_w, vqg_img_h = vqg_img.shape
mask = torch.zeros(vqg_img_w, vqg_img_h)
if borders['up'] != 0:


+ 28
- 8
rudalle/pipelines.py View File

@ -1,4 +1,8 @@
# -*- coding: utf-8 -*-
import os
from glob import glob
from os.path import join
import torch
import torchvision
import transformers
@ -34,10 +38,10 @@ def generate_images(text, tokenizer, dalle, vae, top_k, top_p, images_num, image
sample_scores = []
if image_prompts is not None:
prompts_idx, prompts = image_prompts.image_prompts_idx, image_prompts.image_prompts
prompts = prompts.repeat(images_num, 1)
if use_cache:
use_cache = False
prompts = prompts.repeat(chunk_bs, 1)
if use_cache and image_prompts.allow_cache is False:
print('Warning: use_cache changed to False')
use_cache = False
for idx in tqdm(range(out.shape[1], total_seq_length)):
idx -= text_seq_length
if image_prompts is not None and idx in prompts_idx:
@ -84,7 +88,18 @@ def cherry_pick_by_clip(pil_images, text, ruclip, ruclip_processor, device='cpu'
return top_pil_images, top_scores
def show(pil_images, nrow=4):
def show(pil_images, nrow=4, save_dir=None, show=True):
"""
:param pil_images: list of images in PIL
:param nrow: number of rows
:param save_dir: dir for separately saving of images, example: save_dir='./pics'
"""
if save_dir is not None:
os.makedirs(save_dir, exist_ok=True)
count = len(glob(join(save_dir, 'img_*.png')))
for i, pil_image in enumerate(pil_images):
pil_image.save(join(save_dir, f'img_{count+i}.png'))
imgs = torchvision.utils.make_grid(utils.pil_list_to_torch_tensors(pil_images), nrow=nrow)
if not isinstance(imgs, list):
imgs = [imgs.cpu()]
@ -92,7 +107,12 @@ def show(pil_images, nrow=4):
for i, img in enumerate(imgs):
img = img.detach()
img = torchvision.transforms.functional.to_pil_image(img)
axs[0, i].imshow(np.asarray(img))
axs[0, i].set(xticklabels=[], yticklabels=[], xticks=[], yticks=[])
fix.show()
plt.show()
if save_dir is not None:
count = len(glob(join(save_dir, 'group_*.png')))
img.save(join(save_dir, f'group_{count+i}.png'))
if show:
axs[0, i].imshow(np.asarray(img))
axs[0, i].set(xticklabels=[], yticklabels=[], xticks=[], yticks=[])
if show:
fix.show()
plt.show()

+ 9
- 0
tests/test_show.py View File

@ -0,0 +1,9 @@
# -*- coding: utf-8 -*-
from rudalle.pipelines import show
def test_show(sample_image):
img = sample_image.copy()
img = img.resize((256, 256))
pil_images = [img]*5
show(pil_images, nrow=2, save_dir='/tmp/pics', show=False)

Loading…
Cancel
Save