Browse Source

fill example readme

pull/4/head
shonenkov 5 years ago
parent
commit
b9e0e29c99
5 changed files with 58 additions and 7 deletions
  1. +55
    -0
      README.md
  2. BIN
      pics/rainbow-cherry-pick.png
  3. BIN
      pics/rainbow-full.png
  4. BIN
      pics/rainbow-super-resolution.png
  5. +3
    -7
      rudalle/dalle/__init__.py

+ 55
- 0
README.md View File

@ -1,4 +1,59 @@
# ruDALL-E
### Generate images from texts
```
pip install rudalle==0.0.1rc1
```
### Minimal Example:
[![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/drive/1wGE-046et27oHvNlBNPH07qrEQNE04PQ?usp=sharing)
### generation by ruDALLE:
```python
from rudalle.pipelines import generate_images, show, super_resolution, cherry_pick_by_clip
from rudalle import get_rudalle_model, get_tokenizer, get_vae, get_realesrgan, get_ruclip
from rudalle.utils import seed_everything
device = 'cuda'
dalle = get_rudalle_model('Malevich', pretrained=True, fp16=True, device=device)
realesrgan = get_realesrgan('x4', device=device)
tokenizer = get_tokenizer()
vae = get_vae().to(device)
ruclip, ruclip_processor = get_ruclip('ruclip-vit-base-patch32-v5')
ruclip = ruclip.to(device)
text = 'изображение радуги на фоне ночного города'
pil_images = []
scores = []
seed_everything(42)
for top_k, top_p, images_num in [
(2048, 0.995, 3),
(1536, 0.99, 3),
(1024, 0.99, 3),
(1024, 0.98, 3),
(512, 0.97, 3),
(384, 0.96, 3),
(256, 0.95, 3),
(128, 0.95, 3),
]:
_pil_images, _scores = generate_images(text, tokenizer, dalle, vae, top_k=top_k, images_num=images_num, top_p=top_p)
pil_images += _pil_images
scores += _scores
show(pil_images, 6)
```
![](./pics/rainbow-full.png)
### auto cherry-pick by ruCLIP:
```python
top_images, clip_scores = cherry_pick_by_clip(pil_images, text, ruclip, ruclip_processor, device=device, count=6)
show(top_images, 3)
```
![](./pics/rainbow-cherry-pick.png)
### super resolution:
```python
sr_images = super_resolution(top_images, realesrgan)
show(sr_images, 3)
```
![](./pics/rainbow-super-resolution.png)

BIN
pics/rainbow-cherry-pick.png View File

Before After
Width: 799  |  Height: 539  |  Size: 895 KiB

BIN
pics/rainbow-full.png View File

Before After
Width: 799  |  Height: 539  |  Size: 984 KiB

BIN
pics/rainbow-super-resolution.png View File

Before After
Width: 799  |  Height: 538  |  Size: 768 KiB

+ 3
- 7
rudalle/dalle/__init__.py View File

@ -25,11 +25,9 @@ MODELS = {
cogview_sandwich_layernorm=True,
cogview_pb_relax=True,
vocab_size=16384+128,
image_vocab_size=8192*2,
# image_vocab_size=8192,
image_vocab_size=8192,
),
# repo_id='sberbank-ai/rudalle-Malevich', # TODO update checkpoint
repo_id='shonenkov/rudalle-Malevich',
repo_id='sberbank-ai/rudalle-Malevich',
filename='pytorch_model.bin',
full_description='', # TODO
),
@ -68,9 +66,7 @@ def get_rudalle_model(name, pretrained=True, fp16=False, device='cpu', cache_dir
config_file_url = hf_hub_url(repo_id=config['repo_id'], filename=config['filename'])
cached_download(config_file_url, cache_dir=cache_dir, force_filename=config['filename'])
checkpoint = torch.load(os.path.join(cache_dir, config['filename']), map_location='cpu')
if 'module' in checkpoint:
checkpoint = checkpoint['module']
model.load_state_dict(checkpoint, strict=False)
model.load_state_dict(checkpoint)
if fp16:
model = FP16Module(model)
model.eval()


Loading…
Cancel
Save