# -*- coding: utf-8 -*-
from os.path import join

import torch
import numpy as np
import youtokentome as yttm
from huggingface_hub import hf_hub_url, cached_download


def get_tokenizer(path=None, cache_dir='/tmp/rudalle'):
    """Load a YTTM BPE tokenizer, downloading the default model if needed.

    Args:
        path: filesystem path to a youtokentome BPE model file. When None,
            'bpe.model' is fetched from the 'shonenkov/rudalle-utils' repo on
            the Hugging Face Hub and cached under ``cache_dir``.
        cache_dir: base directory for the downloaded model cache
            (a 'tokenizer' subdirectory is used for the default model).

    Returns:
        A YTTMTokenizerWrapper around the loaded BPE model.
    """
    if path is None:
        repo_id = 'shonenkov/rudalle-utils'
        filename = 'bpe.model'
        cache_dir = join(cache_dir, 'tokenizer')
        config_file_url = hf_hub_url(repo_id=repo_id, filename=filename)
        cached_download(config_file_url, cache_dir=cache_dir, force_filename=filename)
        path = join(cache_dir, filename)
    tokenizer = YTTMTokenizerWrapper(yttm.BPE(model=path))
    print('tokenizer --> ready')
    return tokenizer


class YTTMTokenizerWrapper:
    """Thin wrapper around a youtokentome BPE model with fixed special-token ids."""

    # Special-token ids used when framing / stripping sequences.
    eos_id = 3
    bos_id = 2
    unk_id = 1
    pad_id = 0

    def __init__(self, tokenizer):
        # tokenizer: a loaded yttm.BPE instance.
        self.tokenizer = tokenizer

    def __len__(self):
        return self.vocab_size()

    def get_pad_token_id(self):
        """Return the id the underlying model assigns to the empty subword ''.

        NOTE(review): this asks the model rather than returning the class
        attribute ``pad_id`` (0) — confirm the two agree for the shipped model.
        """
        return self.tokenizer.subword_to_id('')

    def vocab_size(self):
        """Return the size of the BPE vocabulary."""
        return self.tokenizer.vocab_size()

    def encode_text(self, text, text_seq_length, bpe_dropout=0.0):
        """Encode ``text`` into a fixed-length LongTensor of token ids.

        The id stream is wrapped in BOS/EOS, then padded with 0 (== pad_id)
        or truncated to exactly ``text_seq_length`` entries.

        Args:
            text: input string to encode.
            text_seq_length: target sequence length.
            bpe_dropout: BPE-dropout probability passed to the encoder.

        Returns:
            torch.LongTensor of shape (text_seq_length,).
        """
        tokens = self.tokenizer.encode([text], output_type=yttm.OutputType.ID, dropout_prob=bpe_dropout)[0]
        tokens = [self.bos_id] + tokens + [self.eos_id]
        return self.prepare_tokens(tokens, text_seq_length)

    def decode_text(self, encoded):
        """Decode a tensor of token ids back to text, dropping special tokens.

        Args:
            encoded: 1-D tensor of token ids (moved to CPU before decoding).

        Returns:
            The decoded string.
        """
        return self.tokenizer.decode(encoded.cpu().numpy().tolist(), ignore_ids=[
            self.eos_id, self.bos_id, self.unk_id, self.pad_id
        ])[0]

    @staticmethod
    def prepare_tokens(tokens, text_seq_length):
        """Pad (with 0 == pad_id) or truncate ``tokens`` to ``text_seq_length``.

        NOTE(review): truncation cuts from the right, so an over-long text
        loses its trailing EOS token — confirm downstream code tolerates that.

        Args:
            tokens: list of integer token ids.
            text_seq_length: target sequence length.

        Returns:
            torch.LongTensor of shape (text_seq_length,).
        """
        empty_positions = text_seq_length - len(tokens)
        if empty_positions > 0:
            # Pad on the right with integer zeros; the original built the
            # padding via np.zeros(), which round-tripped the ids through
            # float64 before .long() — same values, needless conversion.
            tokens = list(tokens) + [0] * empty_positions  # position tokens after text
        if len(tokens) > text_seq_length:
            tokens = tokens[:text_seq_length]
        return torch.tensor(tokens).long()