| @ -1,8 +1,9 @@ | |||
| taming-transformers==0.0.1 | |||
| more_itertools==8.10.0 | |||
| transformers==4.10.2 | |||
| youtokentome==1.0.6 | |||
| einops==0.3.2 | |||
| more_itertools~=8.10.0 | |||
| transformers~=4.10.2 | |||
| youtokentome~=1.0.6 | |||
| omegaconf>=2.0.0 | |||
| einops~=0.3.2 | |||
| torch | |||
| torchvision | |||
| matplotlib | |||
| @ -0,0 +1,30 @@ | |||
| # -*- coding: utf-8 -*- | |||
| import os | |||
| from transformers import CLIPModel | |||
| from huggingface_hub import hf_hub_url, cached_download | |||
| from .processor import RuCLIPProcessor | |||
# Registry of downloadable ruCLIP checkpoints: maps a model name to the
# Hugging Face repo it lives in and the files that must be fetched.
MODELS = {
    'ruclip-vit-base-patch32-v5': {
        'repo_id': 'sberbank-ai/ru-clip',
        'filenames': ['bpe.model', 'config.json', 'pytorch_model.bin'],
    },
}
def get_ruclip(name, cache_dir='/tmp/rudalle'):
    """Download (if not cached) and load a ruCLIP model with its processor.

    Args:
        name: key into MODELS identifying the checkpoint.
        cache_dir: local cache root; files are stored under ``<cache_dir>/<name>``.

    Returns:
        Tuple ``(ruclip, ruclip_processor)`` — a transformers CLIPModel and a
        RuCLIPProcessor, both loaded from the cached checkpoint folder.
    """
    assert name in MODELS, f'Unknown ruCLIP model: {name}'
    config = MODELS[name]
    repo_id = config['repo_id']
    cache_dir = os.path.join(cache_dir, name)
    for filename in config['filenames']:
        # Files live under '<name>/<filename>' inside the HF repo; the original
        # URL was built with a garbled placeholder instead of the filename,
        # which made every download fetch a non-existent path.
        config_file_url = hf_hub_url(repo_id=repo_id, filename=f'{name}/{filename}')
        cached_download(config_file_url, cache_dir=cache_dir, force_filename=filename)
    ruclip = CLIPModel.from_pretrained(cache_dir)
    ruclip_processor = RuCLIPProcessor.from_pretrained(cache_dir)
    print('ruclip --> ready')
    return ruclip, ruclip_processor
| @ -0,0 +1,68 @@ | |||
| # -*- coding: utf-8 -*- | |||
| import os | |||
| import json | |||
| import torch | |||
| import youtokentome as yttm | |||
| import torchvision.transforms as T | |||
| from torch.nn.utils.rnn import pad_sequence | |||
class RuCLIPProcessor:
    """Prepares inputs for a ruCLIP model.

    Texts are lower-cased, BPE-encoded with youtokentome and wrapped in
    BOS/EOS special tokens; images are converted to RGB, square-cropped to
    ``image_size`` and normalized to tensors.
    """

    # Special-token ids of the BPE vocabulary.
    eos_id = 3
    bos_id = 2
    unk_id = 1
    pad_id = 0

    def __init__(self, tokenizer_path, image_size=224, text_seq_length=76, mean=None, std=None):
        """
        Args:
            tokenizer_path: path to the youtokentome BPE model file.
            image_size: side length images are cropped/resized to.
            text_seq_length: maximum token sequence length (including BOS/EOS).
            mean, std: per-channel normalization statistics; default to the
                ImageNet values when not provided.
        """
        self.tokenizer = yttm.BPE(tokenizer_path)
        self.mean = mean or [0.485, 0.456, 0.406]  # ImageNet defaults
        self.std = std or [0.229, 0.224, 0.225]
        self.image_transform = T.Compose([
            T.Lambda(lambda img: img.convert('RGB') if img.mode != 'RGB' else img),
            # scale/ratio pinned to 1.0 -> deterministic square crop + resize
            T.RandomResizedCrop(image_size, scale=(1., 1.), ratio=(1., 1.)),
            T.ToTensor(),
            T.Normalize(mean=self.mean, std=self.std)
        ])
        self.text_seq_length = text_seq_length
        self.image_size = image_size

    def encode_text(self, text):
        """Tokenize one string.

        Returns:
            Tuple of long tensors ``(token_ids, attention_mask)``, at most
            ``text_seq_length`` tokens, always ending with EOS.
        """
        text = text.lower()
        tokens = self.tokenizer.encode([text], output_type=yttm.OutputType.ID, dropout_prob=0.0)[0]
        tokens = [self.bos_id] + tokens + [self.eos_id]
        if len(tokens) > self.text_seq_length:
            # Truncate but keep the EOS marker; the original plain slice
            # dropped EOS for over-long texts.
            tokens = tokens[:self.text_seq_length - 1] + [self.eos_id]
        mask = [1] * len(tokens)
        return torch.tensor(tokens).long(), torch.tensor(mask).long()

    def decode_text(self, encoded):
        """Decode a tensor of token ids back to text, skipping special tokens."""
        return self.tokenizer.decode(encoded.cpu().numpy().tolist(), ignore_ids=[
            self.eos_id, self.bos_id, self.unk_id, self.pad_id
        ])[0]

    def __call__(self, text=None, images=None, **kwargs):
        """Build a model-input dict from text(s) and/or images.

        Args:
            text: a single string or an iterable of strings.
            images: an iterable of PIL images.

        Returns:
            Dict with any of 'input_ids' / 'attention_mask' (right-padded
            with pad_id == 0) and 'pixel_values'.
        """
        inputs = {}
        if text is not None:
            texts = [text] if isinstance(text, str) else text
            input_ids, masks = [], []
            for item in texts:  # renamed: do not shadow the `text` parameter
                tokens, mask = self.encode_text(item)
                input_ids.append(tokens)
                masks.append(mask)
            inputs['input_ids'] = pad_sequence(input_ids, batch_first=True)
            inputs['attention_mask'] = pad_sequence(masks, batch_first=True)
        if images is not None:
            pixel_values = [self.image_transform(image) for image in images]
            inputs['pixel_values'] = pad_sequence(pixel_values, batch_first=True)
        return inputs

    @classmethod
    def from_pretrained(cls, folder):
        """Build a processor from a checkpoint folder containing 'bpe.model'
        and a Hugging-Face-style 'config.json'."""
        tokenizer_path = os.path.join(folder, 'bpe.model')
        # Context manager: the original json.load(open(...)) leaked the handle.
        with open(os.path.join(folder, 'config.json')) as f:
            config = json.load(f)
        image_size = config['vision_config']['image_size']
        # -1: presumably reserves one position (e.g. for a special token) out
        # of the text encoder's max_position_embeddings — TODO confirm.
        text_seq_length = config['text_config']['max_position_embeddings'] - 1
        mean, std = config.get('mean'), config.get('std')
        return cls(tokenizer_path, image_size=image_size, text_seq_length=text_seq_length, mean=mean, std=std)
| @ -0,0 +1,56 @@ | |||
| # -*- coding: utf-8 -*- | |||
| import os | |||
| import re | |||
| from setuptools import setup | |||
def read(filename):
    """Return the full text of *filename*, resolved relative to this file."""
    path = os.path.join(os.path.dirname(__file__), filename)
    with open(path) as fp:
        return fp.read()
def get_requirements():
    """Parse requirements.txt into a list of pip requirement specifiers.

    VCS links (git+/svn+/hg+) are converted to ``package==version`` pins
    via their ``#egg=<name>-<version>`` fragment; other lines pass through
    unchanged. Blank lines and comments are skipped.
    """
    requirements = []
    for requirement in read('requirements.txt').splitlines():
        requirement = requirement.strip()
        if not requirement or requirement.startswith('#'):
            continue  # robustness: don't emit empty/comment lines as requirements
        if requirement.startswith(('git+', 'svn+', 'hg+')):
            # \d is redundant inside \w — simplified character class.
            parsed_requires = re.findall(r'#egg=([\w\.]+)-([\d\.]+)$', requirement)
            if parsed_requires:
                package, version = parsed_requires[0]
                requirements.append(f'{package}=={version}')
            else:
                # Fixed: the original concatenated literals lacked a separating
                # space ("...versionsuch as...").
                print('WARNING! For correct matching dependency links need to specify package name and version '
                      'such as <dependency url>#egg=<package_name>-<version>')
        else:
            requirements.append(requirement)
    return requirements
def get_links():
    """Return the VCS dependency links (git+/svn+/hg+) from requirements.txt."""
    links = []
    for line in read('requirements.txt').splitlines():
        if line.startswith(('git+', 'svn+', 'hg+')):
            links.append(line)
    return links
def get_version():
    """Get version from the package without actually importing it.

    Scans rudalle/__init__.py for a ``__version__ = '...'`` line and returns
    the version string, or None when no such line exists.
    """
    init = read('rudalle/__init__.py')
    for line in init.split('\n'):
        if line.startswith('__version__'):
            # Parse the literal directly instead of eval()'ing file content —
            # same result for the conventional `__version__ = 'x.y.z'` form,
            # without executing arbitrary code.
            return line.split('=', 1)[1].strip().strip('\'"')
    return None  # explicit: no __version__ assignment found
# Build/packaging metadata for the `rudalle` distribution.
setup(
    name='rudalle',
    version=get_version(),
    author='SberAI, SberDevices',
    author_email='',
    description='',
    # NOTE(review): setuptools convention is dotted package names
    # ('rudalle.dalle'); slash-separated names may confuse some tooling —
    # confirm this installs correctly before changing.
    packages=['rudalle', 'rudalle/dalle', 'rudalle/realesrgan', 'rudalle/ruclip', 'rudalle/vae'],
    package_data={'rudalle/vae': ['*.yml']},
    install_requires=get_requirements(),  # parsed from requirements.txt
    dependency_links=get_links(),  # VCS links (pip's dependency_links is deprecated)
    long_description=read('README.md'),
    long_description_content_type='text/markdown',
)