From ea2e911a1d377ca83318eeb37b38669acec76dac Mon Sep 17 00:00:00 2001 From: Arkhipkin Vladimir <31930051+oriBetelgeuse@users.noreply.github.com> Date: Sun, 7 Nov 2021 18:20:37 +0300 Subject: [PATCH] add crop_first for all type of image prompts (#38) * add crop_first for all type of image prompts * fix bugs with left/right --- rudalle/image_prompts.py | 26 +++++++++++++++++++++----- tests/test_image_prompts.py | 10 +++------- 2 files changed, 24 insertions(+), 12 deletions(-) diff --git a/rudalle/image_prompts.py b/rudalle/image_prompts.py index fd75247..e2778f4 100644 --- a/rudalle/image_prompts.py +++ b/rudalle/image_prompts.py @@ -27,9 +27,25 @@ class ImagePrompts: def _get_image_prompts(self, img, borders, vae, crop_first): if crop_first: - assert borders['right'] + borders['left'] + borders['down'] == 0 - up_border = borders['up'] * 8 - _, _, [_, _, vqg_img] = vae.model.encode(img[:, :, :up_border, :]) + bs, _, img_w, img_h = img.shape + vqg_img_w, vqg_img_h = img_w // 8, img_h // 8 + vqg_img = torch.zeros((bs, vqg_img_w, vqg_img_h), dtype=torch.int32, device=img.device) + if borders['down'] != 0: + down_border = borders['down'] * 8 + _, _, [_, _, down_vqg_img] = vae.model.encode(img[:, :, -down_border:, :]) + vqg_img[:, -borders['down']:, :] = down_vqg_img + if borders['right'] != 0: + right_border = borders['right'] * 8 + _, _, [_, _, right_vqg_img] = vae.model.encode(img[:, :, :, -right_border:]) + vqg_img[:, :, -borders['right']:] = right_vqg_img + if borders['left'] != 0: + left_border = borders['left'] * 8 + _, _, [_, _, left_vqg_img] = vae.model.encode(img[:, :, :, :left_border]) + vqg_img[:, :, :borders['left']] = left_vqg_img + if borders['up'] != 0: + up_border = borders['up'] * 8 + _, _, [_, _, up_vqg_img] = vae.model.encode(img[:, :, :up_border, :]) + vqg_img[:, :borders['up'], :] = up_vqg_img else: _, _, [_, _, vqg_img] = vae.model.encode(img) @@ -40,9 +56,9 @@ class ImagePrompts: if borders['down'] != 0: mask[-borders['down']:, :] = 1. if borders['right'] != 0: - mask[:, :borders['right']] = 1. + mask[:, -borders['right']:] = 1. if borders['left'] != 0: - mask[:, -borders['left']:] = 1. + mask[:, :borders['left']] = 1. mask = mask.reshape(-1).bool() image_prompts = vqg_img.reshape((bs, -1)) diff --git a/tests/test_image_prompts.py b/tests/test_image_prompts.py index 27dcb93..5d9f309 100644 --- a/tests/test_image_prompts.py +++ b/tests/test_image_prompts.py @@ -13,10 +13,6 @@ def test_image_prompts(sample_image, vae, borders, crop_first): img = sample_image.copy() img = img.resize((256, 256)) image_prompt = ImagePrompts(img, borders, vae, crop_first=crop_first) - if crop_first: - assert image_prompt.image_prompts.shape[1] == borders['up'] * 32 - assert len(image_prompt.image_prompts_idx) == borders['up'] * 32 - else: - assert image_prompt.image_prompts.shape[1] == 32 * 32 - assert len(image_prompt.image_prompts_idx) == (borders['up'] + borders['down']) * 32 \ - + (borders['left'] + borders['right']) * (32 - borders['up'] - borders['down']) + assert image_prompt.image_prompts.shape[1] == 32 * 32 + assert len(image_prompt.image_prompts_idx) == (borders['up'] + borders['down']) * 32 \ + + (borders['left'] + borders['right']) * (32 - borders['up'] - borders['down'])