|
|
@ -18,7 +18,6 @@ class ImagePrompts: |
|
|
self.device = device |
|
|
self.device = device |
|
|
img = self._preprocess_img(pil_image) |
|
|
img = self._preprocess_img(pil_image) |
|
|
self.image_prompts_idx, self.image_prompts = self._get_image_prompts(img, borders, vae, crop_first) |
|
|
self.image_prompts_idx, self.image_prompts = self._get_image_prompts(img, borders, vae, crop_first) |
|
|
self.allow_cache = True |
|
|
|
|
|
|
|
|
|
|
|
def _preprocess_img(self, pil_img): |
|
|
def _preprocess_img(self, pil_img): |
|
|
img = torch.tensor(np.array(pil_img.convert('RGB')).transpose(2, 0, 1)) / 255. |
|
|
img = torch.tensor(np.array(pil_img.convert('RGB')).transpose(2, 0, 1)) / 255. |
|
|
@ -26,17 +25,31 @@ class ImagePrompts: |
|
|
img = (2 * img) - 1 |
|
|
img = (2 * img) - 1 |
|
|
return img |
|
|
return img |
|
|
|
|
|
|
|
|
def _get_image_prompts(self, img, borders, vae, crop_first): |
|
|
|
|
|
|
|
|
@staticmethod |
|
|
|
|
|
def _get_image_prompts(img, borders, vae, crop_first): |
|
|
if crop_first: |
|
|
if crop_first: |
|
|
assert borders['right'] + borders['left'] + borders['down'] == 0 |
|
|
|
|
|
up_border = borders['up'] * 8 |
|
|
|
|
|
_, _, [_, _, vqg_img] = vae.model.encode(img[:, :, :up_border, :]) |
|
|
|
|
|
|
|
|
bs, _, img_w, img_h = img.shape |
|
|
|
|
|
vqg_img_w, vqg_img_h = img_w // 8, img_h // 8 |
|
|
|
|
|
vqg_img = torch.zeros((bs, vqg_img_w, vqg_img_h), dtype=torch.int32, device=img.device) |
|
|
|
|
|
if borders['down'] != 0: |
|
|
|
|
|
down_border = borders['down'] * 8 |
|
|
|
|
|
_, _, [_, _, down_vqg_img] = vae.model.encode(img[:, :, -down_border:, :]) |
|
|
|
|
|
vqg_img[:, -borders['down']:, :] = down_vqg_img |
|
|
|
|
|
if borders['right'] != 0: |
|
|
|
|
|
right_border = borders['right'] * 8 |
|
|
|
|
|
_, _, [_, _, right_vqg_img] = vae.model.encode(img[:, :, :, :right_border]) |
|
|
|
|
|
vqg_img[:, :, :borders['right']] = right_vqg_img |
|
|
|
|
|
if borders['left'] != 0: |
|
|
|
|
|
left_border = borders['left'] * 8 |
|
|
|
|
|
_, _, [_, _, left_vqg_img] = vae.model.encode(img[:, :, :, -left_border:]) |
|
|
|
|
|
vqg_img[:, :, -borders['left']:] = left_vqg_img |
|
|
|
|
|
if borders['up'] != 0: |
|
|
|
|
|
up_border = borders['up'] * 8 |
|
|
|
|
|
_, _, [_, _, up_vqg_img] = vae.model.encode(img[:, :, :up_border, :]) |
|
|
|
|
|
vqg_img[:, :borders['up'], :] = up_vqg_img |
|
|
else: |
|
|
else: |
|
|
_, _, [_, _, vqg_img] = vae.model.encode(img) |
|
|
_, _, [_, _, vqg_img] = vae.model.encode(img) |
|
|
|
|
|
|
|
|
if borders['right'] + borders['left'] + borders['down'] != 0: |
|
|
|
|
|
self.allow_cache = False # TODO fix cache in attention |
|
|
|
|
|
|
|
|
|
|
|
bs, vqg_img_w, vqg_img_h = vqg_img.shape |
|
|
bs, vqg_img_w, vqg_img_h = vqg_img.shape |
|
|
mask = torch.zeros(vqg_img_w, vqg_img_h) |
|
|
mask = torch.zeros(vqg_img_w, vqg_img_h) |
|
|
if borders['up'] != 0: |
|
|
if borders['up'] != 0: |
|
|
|