|
|
# -*- coding: utf-8 -*-
|
|
|
import torch
|
|
|
import torchvision
|
|
|
import transformers
|
|
|
import more_itertools
|
|
|
import numpy as np
|
|
|
import matplotlib.pyplot as plt
|
|
|
from tqdm import tqdm
|
|
|
|
|
|
from . import utils
|
|
|
|
|
|
|
|
|
def generate_images(text, tokenizer, dalle, vae, top_k, top_p, images_num, temperature=1.0, bs=8, seed=None,
                    use_cache=True):
    """Autoregressively sample ``images_num`` images for a text prompt.

    The prompt is lower-cased, stripped and encoded with ``tokenizer``. The
    encoded prompt is then extended one token at a time by sampling from
    ``dalle`` until the sequence reaches the model's ``total_seq_length``.
    The trailing ``image_seq_length`` tokens of every sequence are handed to
    ``vae.decode``. Work is done in batches of at most ``bs`` sequences.

    Args:
        text: Prompt string.
        tokenizer: Object exposing ``encode_text(text, text_seq_length=...)``.
        dalle: Model exposing ``get_param(name)`` and callable as
            ``dalle(tokens, attention_mask, has_cache=..., use_cache=...,
            return_loss=False)``, returning ``(logits, has_cache)``.
        vae: Object exposing ``decode(codebooks)``.
        top_k: Forwarded to ``transformers.top_k_top_p_filtering``.
        top_p: Forwarded to ``transformers.top_k_top_p_filtering``.
        images_num: Total number of images to generate.
        temperature: Divisor applied to the logits before filtering.
        bs: Maximum number of images sampled per batch.
        seed: If not ``None``, passed to ``utils.seed_everything`` first.
        use_cache: Forwarded to ``dalle`` on every sampling step.

    Returns:
        A ``(pil_images, scores)`` tuple of lists. ``pil_images`` collects
        the output of ``utils.torch_tensors_to_pil_list`` for every batch.
        Each score is the sum, over all sampling steps, of the probability
        (not log-probability) of the token that was sampled for that image.
    """
    if seed is not None:
        utils.seed_everything(seed)

    # Model hyper-parameters, fetched in a fixed order.
    vocab_size, text_seq_length, image_seq_length, total_seq_length, device = (
        dalle.get_param(name)
        for name in ('vocab_size', 'text_seq_length', 'image_seq_length', 'total_seq_length', 'device')
    )

    prompt = text.lower().strip()
    input_ids = tokenizer.encode_text(prompt, text_seq_length=text_seq_length)

    pil_images = []
    scores = []
    for batch_indices in more_itertools.chunked(range(images_num), bs):
        # The last batch may be smaller than ``bs``.
        batch_size = len(batch_indices)
        with torch.no_grad():
            all_ones = torch.ones((batch_size, 1, total_seq_length, total_seq_length), device=device)
            attention_mask = torch.tril(all_ones)
            tokens = input_ids.unsqueeze(0).repeat(batch_size, 1).to(device)
            cache_ready = False
            step_probs = []
            for _ in tqdm(range(tokens.shape[1], total_seq_length)):
                logits, cache_ready = dalle(tokens, attention_mask,
                                            has_cache=cache_ready, use_cache=use_cache, return_loss=False)
                # Keep only the last position and the vocabulary slots from
                # index ``vocab_size`` onward (presumably the image-token
                # range -- confirm against the model definition).
                next_logits = logits[:, -1, vocab_size:]
                next_logits /= temperature
                filtered = transformers.top_k_top_p_filtering(next_logits, top_k=top_k, top_p=top_p)
                probs = torch.nn.functional.softmax(filtered, dim=-1)
                picked = torch.multinomial(probs, 1)
                # Record the probability of the token drawn for each sequence.
                rows = torch.arange(probs.size(0))
                step_probs.append(probs[rows, picked.transpose(0, 1)])
                tokens = torch.cat((tokens, picked), dim=-1)
            codebooks = tokens[:, -image_seq_length:]
            decoded = vae.decode(codebooks)
            pil_images.extend(utils.torch_tensors_to_pil_list(decoded))
            # Stack the per-step rows and sum over steps: one score per image.
            batch_scores = torch.cat(step_probs).sum(0)
            scores.extend(batch_scores.detach().cpu().numpy().tolist())
    return pil_images, scores
|
|
|
|
|
|
|
|
|
def show(pil_images, nrow=4):
    """Draw the given images as a single grid on a matplotlib figure.

    The images are converted with ``utils.pil_list_to_torch_tensors``,
    tiled by ``torchvision.utils.make_grid`` and rendered with ``imshow``
    on a 14x14 figure whose axis ticks and tick labels are removed.

    Args:
        pil_images: Images accepted by ``utils.pil_list_to_torch_tensors``.
        nrow: Forwarded to ``torchvision.utils.make_grid``.
    """
    tensors = utils.pil_list_to_torch_tensors(pil_images)
    grids = torchvision.utils.make_grid(tensors, nrow=nrow)
    if not isinstance(grids, list):
        # Wrap a single grid so the drawing loop below is uniform.
        grids = [grids.cpu()]
    # squeeze=False keeps ``axes`` two-dimensional even for one column.
    _, axes = plt.subplots(ncols=len(grids), squeeze=False, figsize=(14, 14))
    for axis, grid in zip(axes[0], grids):
        picture = torchvision.transforms.functional.to_pil_image(grid.detach())
        axis.imshow(np.asarray(picture))
        axis.set(xticklabels=[], yticklabels=[], xticks=[], yticks=[])
|