You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-'), and can be up to 35 characters long.
 
 

61 lines
2.6 KiB

# -*- coding: utf-8 -*-
import torch
import torchvision
import transformers
import more_itertools
import numpy as np
import matplotlib.pyplot as plt
from tqdm import tqdm
from . import utils
def generate_images(text, tokenizer, dalle, vae, top_k, top_p, images_num, temperature=1.0, bs=8, seed=None,
                    use_cache=True):
    """Sample images for a text prompt and return them with per-sequence scores.

    The prompt is lower-cased, stripped and encoded with ``tokenizer``. Tokens are
    then sampled one position at a time from ``dalle`` until each sequence reaches
    ``total_seq_length``; the trailing ``image_seq_length`` tokens of every
    sequence are handed to ``vae.decode``.

    Args:
        text: Prompt string.
        tokenizer: Object providing ``encode_text(text, text_seq_length=...)``.
        dalle: Callable model exposing ``get_param`` for ``vocab_size``,
            ``text_seq_length``, ``image_seq_length``, ``total_seq_length``
            and ``device``.
        vae: Object providing ``decode(codebooks)``.
        top_k: Forwarded to ``transformers.top_k_top_p_filtering``.
        top_p: Forwarded to ``transformers.top_k_top_p_filtering``.
        images_num: Total number of sequences to sample.
        temperature: Divisor applied to the logits before filtering.
        bs: Maximum number of sequences sampled together in one batch.
        seed: When not ``None``, passed to ``utils.seed_everything`` first.
        use_cache: Forwarded to ``dalle`` on every sampling step.

    Returns:
        A ``(pil_images, scores)`` tuple of lists. ``pil_images`` accumulates the
        output of ``utils.torch_tensors_to_pil_list``; ``scores`` holds, for each
        sampled sequence, the sum over steps of the probability of the token
        that was drawn.
    """
    if seed is not None:
        utils.seed_everything(seed)

    vocab_size, text_seq_length, image_seq_length, total_seq_length, device = (
        dalle.get_param(key)
        for key in ('vocab_size', 'text_seq_length', 'image_seq_length', 'total_seq_length', 'device')
    )

    # NOTE(review): assumes encode_text returns a 1-D token tensor -- confirm against the tokenizer.
    prompt_ids = tokenizer.encode_text(text.lower().strip(), text_seq_length=text_seq_length)

    pil_images, scores = [], []
    for batch_indices in more_itertools.chunked(range(images_num), bs):
        # The last batch may be smaller than ``bs``.
        batch_size = len(batch_indices)
        with torch.no_grad():
            mask_shape = (batch_size, 1, total_seq_length, total_seq_length)
            attention_mask = torch.tril(torch.ones(mask_shape, device=device))
            tokens = prompt_ids.unsqueeze(0).repeat(batch_size, 1).to(device)
            has_cache = False
            step_probs = []
            for _ in tqdm(range(tokens.shape[1], total_seq_length)):
                logits, has_cache = dalle(
                    tokens,
                    attention_mask,
                    has_cache=has_cache,
                    use_cache=use_cache,
                    return_loss=False,
                )
                # Keep only the last position and skip the first ``vocab_size`` logits
                # (presumably the text vocabulary -- confirm against the model).
                next_logits = logits[:, -1, vocab_size:]
                next_logits /= temperature
                filtered = transformers.top_k_top_p_filtering(next_logits, top_k=top_k, top_p=top_p)
                probs = torch.nn.functional.softmax(filtered, dim=-1)
                next_token = torch.multinomial(probs, 1)
                # Probability of each drawn token, laid out as a (1, batch) row so the
                # rows can be stacked per step and summed per sequence afterwards.
                step_probs.append(probs.gather(-1, next_token).transpose(0, 1))
                tokens = torch.cat((tokens, next_token), dim=-1)
            image_tokens = tokens[:, -image_seq_length:]
            decoded = vae.decode(image_tokens)
            pil_images.extend(utils.torch_tensors_to_pil_list(decoded))
            scores.extend(torch.cat(step_probs).sum(0).detach().cpu().numpy().tolist())
    return pil_images, scores
def show(pil_images, nrow=4):
    """Draw the given PIL images as a grid on a matplotlib figure.

    Args:
        pil_images: Images to display; converted with
            ``utils.pil_list_to_torch_tensors`` before being tiled.
        nrow: Forwarded to ``torchvision.utils.make_grid``.
    """
    grid = torchvision.utils.make_grid(utils.pil_list_to_torch_tensors(pil_images), nrow=nrow)
    # Wrap a single grid in a list so the plotting loop handles both cases.
    panels = grid if isinstance(grid, list) else [grid.cpu()]
    # squeeze=False keeps the axes array 2-D even when there is only one column.
    _, axes = plt.subplots(ncols=len(panels), squeeze=False, figsize=(14, 14))
    for axis, panel in zip(axes[0], panels):
        pil_panel = torchvision.transforms.functional.to_pil_image(panel.detach())
        axis.imshow(np.asarray(pil_panel))
        # Hide ticks and tick labels so only the image is visible.
        axis.set(xticklabels=[], yticklabels=[], xticks=[], yticks=[])