Browse Source

fix changes

pull/3/head
Alex Shonenkov 5 years ago
parent
commit
ce44515ae8
4 changed files with 118 additions and 1 deletions
  1. +4
    -1
      rudalle/__init__.py
  2. +16
    -0
      rudalle/pipelines.py
  3. +30
    -0
      rudalle/ruclip/__init__.py
  4. +68
    -0
      rudalle/ruclip/processor.py

+ 4
- 1
rudalle/__init__.py View File

@ -3,7 +3,8 @@ from .vae import get_vae
from .dalle import get_rudalle_model
from .tokenizer import get_tokenizer
from .realesrgan import get_realesrgan
from . import vae, dalle, tokenizer, realesrgan, pipelines
from .ruclip import get_ruclip
from . import vae, dalle, tokenizer, realesrgan, pipelines, ruclip
__all__ = [
@ -11,8 +12,10 @@ __all__ = [
'get_rudalle_model',
'get_tokenizer',
'get_realesrgan',
'get_ruclip',
'vae',
'dalle',
'ruclip',
'tokenizer',
'realesrgan',
'pipelines',


+ 16
- 0
rudalle/pipelines.py View File

@ -58,6 +58,22 @@ def super_resolution(pil_images, realesrgan):
return result
def cherry_pick_by_clip(pil_images, text, ruclip, ruclip_processor, device='cpu', count=4):
    """Rank generated images by CLIP similarity to ``text`` and keep the best ones.

    Args:
        pil_images: list of PIL images to score.
        text: query string the images are ranked against.
        ruclip: CLIP model; its output must expose ``logits_per_image``.
        ruclip_processor: callable producing model inputs from (text, images).
        device: device the prepared tensors are moved to before the forward pass.
        count: number of top-ranked images to return.

    Returns:
        Tuple ``(top_pil_images, top_scores)``: the ``count`` best images and
        their softmax-normalized similarity scores, sorted descending.
    """
    with torch.no_grad():
        inputs = ruclip_processor(text=text, images=pil_images)
        # Move every prepared tensor to the target device.
        for key in inputs.keys():
            inputs[key] = inputs[key].to(device)
        outputs = ruclip(**inputs)
        # One similarity score per image, normalized over the whole batch.
        sims = outputs.logits_per_image.view(-1).softmax(dim=0)
    items = [
        {'img_index': index, 'cosine': sim}
        for index, sim in enumerate(sims.cpu().numpy())
    ]
    items = sorted(items, key=lambda x: x['cosine'], reverse=True)[:count]
    top_pil_images = [pil_images[x['img_index']] for x in items]
    # BUG FIX: the original indexed pil_images with the float score
    # (pil_images[x['cosine']]); the scores themselves must be returned.
    top_scores = [x['cosine'] for x in items]
    return top_pil_images, top_scores
def show(pil_images, nrow=4):
imgs = torchvision.utils.make_grid(utils.pil_list_to_torch_tensors(pil_images), nrow=nrow)
if not isinstance(imgs, list):


+ 30
- 0
rudalle/ruclip/__init__.py View File

@ -0,0 +1,30 @@
# -*- coding: utf-8 -*-
import os
from transformers import CLIPModel
from huggingface_hub import hf_hub_url, cached_download
from .processor import RuCLIPProcessor
# Registry of downloadable ruCLIP checkpoints: hub repo plus the files each needs.
MODELS = {
    'ruclip-vit-base-patch32-v5': dict(
        repo_id='sberbank-ai/ru-clip',
        filenames=[
            'bpe.model', 'config.json', 'pytorch_model.bin'
        ]
    ),
}


def get_ruclip(name, cache_dir='/tmp/rudalle'):
    """Download (if not cached) and load a ruCLIP model with its processor.

    Args:
        name: key in ``MODELS`` identifying the checkpoint.
        cache_dir: local root directory for cached files; the model gets its
            own subdirectory under it.

    Returns:
        Tuple ``(ruclip, ruclip_processor)`` ready for inference.
    """
    assert name in MODELS
    config = MODELS[name]
    repo_id = config['repo_id']
    cache_dir = os.path.join(cache_dir, name)
    for filename in config['filenames']:
        # BUG FIX: the hub URL must point at each concrete file inside the
        # model's folder; the original built a constant/garbled path instead
        # of interpolating the current filename.
        config_file_url = hf_hub_url(repo_id=repo_id, filename=f'{name}/{filename}')
        cached_download(config_file_url, cache_dir=cache_dir, force_filename=filename)
    ruclip = CLIPModel.from_pretrained(cache_dir)
    ruclip_processor = RuCLIPProcessor.from_pretrained(cache_dir)
    print('ruclip --> ready')
    return ruclip, ruclip_processor

+ 68
- 0
rudalle/ruclip/processor.py View File

@ -0,0 +1,68 @@
# -*- coding: utf-8 -*-
import os
import json
import torch
import youtokentome as yttm
import torchvision.transforms as T
from torch.nn.utils.rnn import pad_sequence
class RuCLIPProcessor:
eos_id = 3
bos_id = 2
unk_id = 1
pad_id = 0
def __init__(self, tokenizer_path, image_size=224, text_seq_length=76, mean=None, std=None):
self.tokenizer = yttm.BPE(tokenizer_path)
self.mean = mean or [0.485, 0.456, 0.406]
self.std = std or [0.229, 0.224, 0.225]
self.image_transform = T.Compose([
T.Lambda(lambda img: img.convert('RGB') if img.mode != 'RGB' else img),
T.RandomResizedCrop(image_size, scale=(1., 1.), ratio=(1., 1.)),
T.ToTensor(),
T.Normalize(mean=self.mean, std=self.std)
])
self.text_seq_length = text_seq_length
self.image_size = image_size
def encode_text(self, text):
text = text.lower()
tokens = self.tokenizer.encode([text], output_type=yttm.OutputType.ID, dropout_prob=0.0)[0]
tokens = [self.bos_id] + tokens + [self.eos_id]
tokens = tokens[:self.text_seq_length]
mask = [1] * len(tokens)
return torch.tensor(tokens).long(), torch.tensor(mask).long()
def decode_text(self, encoded):
return self.tokenizer.decode(encoded.cpu().numpy().tolist(), ignore_ids=[
self.eos_id, self.bos_id, self.unk_id, self.pad_id
])[0]
def __call__(self, text=None, images=None, **kwargs):
inputs = {}
if text is not None:
input_ids, masks = [], []
texts = [text] if isinstance(text, str) else text
for text in texts:
tokens, mask = self.encode_text(text)
input_ids.append(tokens)
masks.append(mask)
inputs['input_ids'] = pad_sequence(input_ids, batch_first=True)
inputs['attention_mask'] = pad_sequence(masks, batch_first=True)
if images is not None:
pixel_values = []
for i, image in enumerate(images):
pixel_values.append(self.image_transform(image))
inputs['pixel_values'] = pad_sequence(pixel_values, batch_first=True)
return inputs
@classmethod
def from_pretrained(cls, folder):
tokenizer_path = os.path.join(folder, 'bpe.model')
config = json.load(open(os.path.join(folder, 'config.json')))
image_size = config['vision_config']['image_size']
text_seq_length = config['text_config']['max_position_embeddings'] - 1
mean, std = config.get('mean'), config.get('std')
return cls(tokenizer_path, image_size=image_size, text_seq_length=text_seq_length, mean=mean, std=std)

Loading…
Cancel
Save