|
|
|
@ -10,8 +10,8 @@ from tqdm import tqdm |
|
|
|
from . import utils |
|
|
|
|
|
|
|
|
|
|
|
def generate_images(text, tokenizer, dalle, vae, top_k, top_p, images_num, temperature=1.0, bs=8, seed=None, |
|
|
|
use_cache=True): |
|
|
|
def generate_images(text, tokenizer, dalle, vae, top_k, top_p, images_num, image_prompts=None, temperature=1.0, bs=8, |
|
|
|
seed=None, use_cache=True): |
|
|
|
# TODO docstring |
|
|
|
if seed is not None: |
|
|
|
utils.seed_everything(seed) |
|
|
|
@ -32,16 +32,22 @@ def generate_images(text, tokenizer, dalle, vae, top_k, top_p, images_num, tempe |
|
|
|
out = input_ids.unsqueeze(0).repeat(chunk_bs, 1).to(device) |
|
|
|
has_cache = False |
|
|
|
sample_scores = [] |
|
|
|
for _ in tqdm(range(out.shape[1], total_seq_length)): |
|
|
|
logits, has_cache = dalle(out, attention_mask, |
|
|
|
has_cache=has_cache, use_cache=use_cache, return_loss=False) |
|
|
|
logits = logits[:, -1, vocab_size:] |
|
|
|
logits /= temperature |
|
|
|
filtered_logits = transformers.top_k_top_p_filtering(logits, top_k=top_k, top_p=top_p) |
|
|
|
probs = torch.nn.functional.softmax(filtered_logits, dim=-1) |
|
|
|
sample = torch.multinomial(probs, 1) |
|
|
|
sample_scores.append(probs[torch.arange(probs.size(0)), sample.transpose(0, 1)]) |
|
|
|
out = torch.cat((out, sample), dim=-1) |
|
|
|
if image_prompts is not None: |
|
|
|
prompts_idx, prompts = image_prompts.image_prompts_idx, image_prompts.image_prompts |
|
|
|
prompts = prompts.repeat(images_num, 1) |
|
|
|
for idx in tqdm(range(out.shape[1], total_seq_length)): |
|
|
|
if image_prompts is not None and idx in prompts_idx: |
|
|
|
out = torch.cat((out, prompts[:, idx].unsqueeze(1)), dim=-1) |
|
|
|
else: |
|
|
|
logits, has_cache = dalle(out, attention_mask, |
|
|
|
has_cache=has_cache, use_cache=use_cache, return_loss=False) |
|
|
|
logits = logits[:, -1, vocab_size:] |
|
|
|
logits /= temperature |
|
|
|
filtered_logits = transformers.top_k_top_p_filtering(logits, top_k=top_k, top_p=top_p) |
|
|
|
probs = torch.nn.functional.softmax(filtered_logits, dim=-1) |
|
|
|
sample = torch.multinomial(probs, 1) |
|
|
|
sample_scores.append(probs[torch.arange(probs.size(0)), sample.transpose(0, 1)]) |
|
|
|
out = torch.cat((out, sample), dim=-1) |
|
|
|
codebooks = out[:, -image_seq_length:] |
|
|
|
images = vae.decode(codebooks) |
|
|
|
pil_images += utils.torch_tensors_to_pil_list(images) |
|
|
|
|