|
|
|
@@ -4,29 +4,33 @@ import numpy as np
 
 
 class ImagePrompts:
 
-    def __init__(self, pil_image, borders, vae, device, crop_first=False):
+    def __init__(self, pil_image, borders, vae, device='cpu', crop_first=False):
""" |
|
|
|
Args: |
|
|
|
pil_image (PIL.Image): image in PIL format |
|
|
|
borders (dict[str] | int): borders that we croped from pil_image |
|
|
|
vae (VQGanGumbelVAE): VQGAN model for image encoding |
|
|
|
device (str): cpu or cuda |
|
|
|
crop_first (bool): if True, croped image before VQGAN encoding |
|
|
|
""" |
|
|
|
         self.device = device
         self.vae = vae
-        self.__init_image_prompts(pil_image, borders, crop_first)
+        img = self._preprocess_img(pil_image)
+        self.image_prompts_idx, self.image_prompts = self._get_image_prompts(img, borders, vae, crop_first)
 
-    def __init_image_prompts(self, pil_image, borders, crop_first):
-        img = self.preprocess_img(pil_image)
-        self.image_prompts_idx, self.image_prompts = self.get_image_prompts(img, borders, crop_first)
-
-    def preprocess_img(self, pil_img):
+    def _preprocess_img(self, pil_img):
         img = torch.tensor(np.array(pil_img).transpose(2, 0, 1)) / 255.
         img = img.unsqueeze(0).to(self.device, dtype=torch.float32)
+        img = (2 * img) - 1
         return img
 
-    def get_image_prompts(self, img, borders, crop_first=False):
-        img = (2 * img) - 1
+    @staticmethod
+    def _get_image_prompts(img, borders, vae, crop_first=False):
         if crop_first:
             assert borders['right'] + borders['left'] + borders['down'] == 0
             up_border = borders['up'] * 8
-            _, _, [_, _, vqg_img] = self.vae.model.encode(img[:, :, :up_border, :])
+            _, _, [_, _, vqg_img] = vae.model.encode(img[:, :, :up_border, :])
         else:
-            _, _, [_, _, vqg_img] = self.vae.model.encode(img)
+            _, _, [_, _, vqg_img] = vae.model.encode(img)
 
         bs, vqg_img_w, vqg_img_h = vqg_img.shape
         mask = torch.zeros(vqg_img_w, vqg_img_h)
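A minimal usage sketch of the refactored class, assuming the package's `get_vae` loader, a 256x256 input image, and illustrative border values; none of these are part of this diff:

    from PIL import Image
    from rudalle import get_vae  # assumed helper from the surrounding package

    vae = get_vae()
    pil_image = Image.open('example.jpg').resize((256, 256))  # assumed input size
    # Borders are counted in VQGAN token units (one token spans 8 pixels, hence
    # the `* 8` above); with crop_first=True only the top border may be
    # non-zero, per the assert in _get_image_prompts.
    borders = {'up': 4, 'left': 0, 'right': 0, 'down': 0}

    prompts = ImagePrompts(pil_image, borders, vae, device='cpu', crop_first=True)
    # prompts.image_prompts_idx and prompts.image_prompts can then be fed to
    # the sampling loop as fixed image tokens.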
|
|
|
|