Browse Source

add image_prompts

pull/2/head
oriBetelgeuse 5 years ago
parent
commit
0000ee7a3d
2 changed files with 65 additions and 12 deletions
  1. +47
    -0
      rudalle/image_prompts.py
  2. +18
    -12
      rudalle/pipelines.py

+ 47
- 0
rudalle/image_prompts.py View File

@ -0,0 +1,47 @@
import torch
import numpy as np
class ImagePrompts:
    """Pre-computed VQGAN token prompts taken from the borders of an image.

    The image is encoded once with the VQGAN encoder. Token-grid positions
    that fall inside the requested borders (widths given in VQGAN cells,
    one cell == 8 pixels) are exposed as:

    - ``image_prompts_idx`` — set of flat token positions covered by the
      borders, and
    - ``image_prompts`` — the full token grid flattened to ``(bs, w * h)``,

    so a generation loop can force those tokens instead of sampling them.
    """

    def __init__(self, pil_image, borders, vae, device, crop_first=False):
        """
        :param pil_image: source RGB PIL image
        :param borders: dict with keys ``'up'``, ``'down'``, ``'left'``,
            ``'right'`` giving border widths in VQGAN cells
        :param vae: VQGAN wrapper exposing ``vae.model.encode``
        :param device: torch device the image tensor is moved to
        :param crop_first: if True, crop the image to the top border before
            encoding (``'up'`` must then be the only non-zero border)
        """
        self.device = device
        self.vae = vae
        self.__init_image_prompts(pil_image, borders, crop_first)

    def __init_image_prompts(self, pil_image, borders, crop_first):
        # Encode at construction time; results are cached as attributes.
        img = self.preprocess_img(pil_image)
        self.image_prompts_idx, self.image_prompts = self.get_image_prompts(img, borders, crop_first)

    def preprocess_img(self, pil_img):
        """Convert a PIL image (H, W, C) to a (1, C, H, W) float tensor in [0, 1]."""
        img = torch.tensor(np.array(pil_img).transpose(2, 0, 1)) / 255.
        img = img.unsqueeze(0).to(self.device, dtype=torch.float32)
        return img

    def get_image_prompts(self, img, borders, crop_first=False):
        """Encode ``img`` and collect the token positions covered by the borders.

        :param img: (1, C, H, W) float tensor in [0, 1] (see ``preprocess_img``)
        :param borders: border widths in VQGAN cells, keyed 'up'/'down'/'left'/'right'
        :param crop_first: encode only the top ``borders['up']`` cells of the image
        :return: ``(image_prompts_idx, image_prompts)`` — a set of flat border
            token positions, and the token grid flattened to ``(bs, w * h)``
        """
        img = (2 * img) - 1  # VQGAN expects inputs in [-1, 1]
        if crop_first:
            # Only a top border is meaningful when cropping before encoding.
            assert borders['right'] + borders['left'] + borders['down'] == 0
            up_border = borders['up'] * 8  # one VQGAN cell == 8 pixels
            _, _, [_, _, vqg_img] = self.vae.model.encode(img[:, :, :up_border, :])
        else:
            _, _, [_, _, vqg_img] = self.vae.model.encode(img)

        bs, vqg_img_w, vqg_img_h = vqg_img.shape
        mask = torch.zeros(vqg_img_w, vqg_img_h)
        if borders['up'] != 0:
            mask[:borders['up'], :] = 1.
        if borders['down'] != 0:
            mask[-borders['down']:, :] = 1.
        # BUGFIX: 'left' and 'right' were swapped — the left border is the
        # FIRST columns of the token grid, the right border the LAST columns
        # (mirroring how 'up' masks the first rows and 'down' the last rows).
        if borders['left'] != 0:
            mask[:, :borders['left']] = 1.
        if borders['right'] != 0:
            mask[:, -borders['right']:] = 1.

        mask = mask.reshape(-1).bool()
        image_prompts = vqg_img.reshape((bs, -1))
        image_prompts_idx = np.arange(vqg_img_w * vqg_img_h)
        image_prompts_idx = set(image_prompts_idx[mask])
        return image_prompts_idx, image_prompts

+ 18
- 12
rudalle/pipelines.py View File

@ -10,8 +10,8 @@ from tqdm import tqdm
from . import utils
def generate_images(text, tokenizer, dalle, vae, top_k, top_p, images_num, temperature=1.0, bs=8, seed=None,
use_cache=True):
def generate_images(text, tokenizer, dalle, vae, top_k, top_p, images_num, image_prompts=None, temperature=1.0, bs=8,
seed=None, use_cache=True):
# TODO docstring
if seed is not None:
utils.seed_everything(seed)
@ -32,16 +32,22 @@ def generate_images(text, tokenizer, dalle, vae, top_k, top_p, images_num, tempe
out = input_ids.unsqueeze(0).repeat(chunk_bs, 1).to(device)
has_cache = False
sample_scores = []
for _ in tqdm(range(out.shape[1], total_seq_length)):
logits, has_cache = dalle(out, attention_mask,
has_cache=has_cache, use_cache=use_cache, return_loss=False)
logits = logits[:, -1, vocab_size:]
logits /= temperature
filtered_logits = transformers.top_k_top_p_filtering(logits, top_k=top_k, top_p=top_p)
probs = torch.nn.functional.softmax(filtered_logits, dim=-1)
sample = torch.multinomial(probs, 1)
sample_scores.append(probs[torch.arange(probs.size(0)), sample.transpose(0, 1)])
out = torch.cat((out, sample), dim=-1)
if image_prompts is not None:
prompts_idx, prompts = image_prompts.image_prompts_idx, image_prompts.image_prompts
prompts = prompts.repeat(images_num, 1)
for idx in tqdm(range(out.shape[1], total_seq_length)):
if image_prompts is not None and idx in prompts_idx:
out = torch.cat((out, prompts[:, idx].unsqueeze(1)), dim=-1)
else:
logits, has_cache = dalle(out, attention_mask,
has_cache=has_cache, use_cache=use_cache, return_loss=False)
logits = logits[:, -1, vocab_size:]
logits /= temperature
filtered_logits = transformers.top_k_top_p_filtering(logits, top_k=top_k, top_p=top_p)
probs = torch.nn.functional.softmax(filtered_logits, dim=-1)
sample = torch.multinomial(probs, 1)
sample_scores.append(probs[torch.arange(probs.size(0)), sample.transpose(0, 1)])
out = torch.cat((out, sample), dim=-1)
codebooks = out[:, -image_seq_length:]
images = vae.decode(codebooks)
pil_images += utils.torch_tensors_to_pil_list(images)


Loading…
Cancel
Save