Browse Source

fix bugs

pull/2/head
oriBetelgeuse 5 years ago
parent
commit
1c95f7158c
2 changed files with 22 additions and 14 deletions
  1. +18
    -14
      rudalle/image_prompts.py
  2. +4
    -0
      rudalle/pipelines.py

+ 18
- 14
rudalle/image_prompts.py View File

@ -4,29 +4,33 @@ import numpy as np
class ImagePrompts:
def __init__(self, pil_image, borders, vae, device, crop_first=False):
def __init__(self, pil_image, borders, vae, device='cpu', crop_first=False):
"""
Args:
pil_image (PIL.Image): image in PIL format
borders (dict[str] | int): borders that we cropped from pil_image
vae (VQGanGumbelVAE): VQGAN model for image encoding
device (str): cpu or cuda
crop_first (bool): if True, crop the image before VQGAN encoding
"""
self.device = device
self.vae = vae
self.__init_image_prompts(pil_image, borders, crop_first)
img = self._preprocess_img(pil_image)
self.image_prompts_idx, self.image_prompts = self._get_image_prompts(img, borders, vae, crop_first)
def __init_image_prompts(self, pil_image, borders, crop_first):
img = self.preprocess_img(pil_image)
self.image_prompts_idx, self.image_prompts = self.get_image_prompts(img, borders, crop_first)
def preprocess_img(self, pil_img):
def _preprocess_img(self, pil_img):
img = torch.tensor(np.array(pil_img).transpose(2, 0, 1)) / 255.
img = img.unsqueeze(0).to(self.device, dtype=torch.float32)
img = (2 * img) - 1
return img
def get_image_prompts(self, img, borders, crop_first=False):
img = (2 * img) - 1
@staticmethod
def _get_image_prompts(img, borders, vae, crop_first=False):
if crop_first:
assert borders['right'] + borders['left'] + borders['down'] == 0
up_border = borders['up'] * 8
_, _, [_, _, vqg_img] = self.vae.model.encode(img[:, :, :up_border, :])
up_border = borders['up'] * 7
_, _, [_, _, vqg_img] = vae.model.encode(img[:, :, :up_border, :])
else:
_, _, [_, _, vqg_img] = self.vae.model.encode(img)
_, _, [_, _, vqg_img] = vae.model.encode(img)
bs, vqg_img_w, vqg_img_h = vqg_img.shape
mask = torch.zeros(vqg_img_w, vqg_img_h)


+ 4
- 0
rudalle/pipelines.py View File

@ -35,7 +35,11 @@ def generate_images(text, tokenizer, dalle, vae, top_k, top_p, images_num, image
if image_prompts is not None:
prompts_idx, prompts = image_prompts.image_prompts_idx, image_prompts.image_prompts
prompts = prompts.repeat(images_num, 1)
if use_cache:
use_cache = False
print("Warning: use_cache changed to False")
for idx in tqdm(range(out.shape[1], total_seq_length)):
idx -= text_seq_lengths
if image_prompts is not None and idx in prompts_idx:
out = torch.cat((out, prompts[:, idx].unsqueeze(1)), dim=-1)
else:


Loading…
Cancel
Save