Compare commits

...

2 changed files with 24 additions and 12 deletions
Split View
  1. +21
    -5
      rudalle/image_prompts.py
  2. +3
    -7
      tests/test_image_prompts.py

+ 21
- 5
rudalle/image_prompts.py View File

@ -27,9 +27,25 @@ class ImagePrompts:
def _get_image_prompts(self, img, borders, vae, crop_first):
if crop_first:
assert borders['right'] + borders['left'] + borders['down'] == 0
up_border = borders['up'] * 8
_, _, [_, _, vqg_img] = vae.model.encode(img[:, :, :up_border, :])
bs, _, img_w, img_h = img.shape
vqg_img_w, vqg_img_h = img_w // 8, img_h // 8
vqg_img = torch.zeros((bs, vqg_img_w, vqg_img_h), dtype=torch.int32, device=img.device)
if borders['down'] != 0:
down_border = borders['down'] * 8
_, _, [_, _, down_vqg_img] = vae.model.encode(img[:, :, -down_border:, :])
vqg_img[:, -borders['down']:, :] = down_vqg_img
if borders['right'] != 0:
right_border = borders['right'] * 8
_, _, [_, _, right_vqg_img] = vae.model.encode(img[:, :, :, -right_border:])
vqg_img[:, :, -borders['right']:] = right_vqg_img
if borders['left'] != 0:
left_border = borders['left'] * 8
_, _, [_, _, left_vqg_img] = vae.model.encode(img[:, :, :, :left_border])
vqg_img[:, :, :borders['left']] = left_vqg_img
if borders['up'] != 0:
up_border = borders['up'] * 8
_, _, [_, _, up_vqg_img] = vae.model.encode(img[:, :, :up_border, :])
vqg_img[:, :borders['up'], :] = up_vqg_img
else:
_, _, [_, _, vqg_img] = vae.model.encode(img)
@ -40,9 +56,9 @@ class ImagePrompts:
if borders['down'] != 0:
mask[-borders['down']:, :] = 1.
if borders['right'] != 0:
mask[:, :borders['right']] = 1.
mask[:, -borders['right']:] = 1.
if borders['left'] != 0:
mask[:, -borders['left']:] = 1.
mask[:, :borders['left']] = 1.
mask = mask.reshape(-1).bool()
image_prompts = vqg_img.reshape((bs, -1))


+ 3
- 7
tests/test_image_prompts.py View File

@ -13,10 +13,6 @@ def test_image_prompts(sample_image, vae, borders, crop_first):
img = sample_image.copy()
img = img.resize((256, 256))
image_prompt = ImagePrompts(img, borders, vae, crop_first=crop_first)
if crop_first:
assert image_prompt.image_prompts.shape[1] == borders['up'] * 32
assert len(image_prompt.image_prompts_idx) == borders['up'] * 32
else:
assert image_prompt.image_prompts.shape[1] == 32 * 32
assert len(image_prompt.image_prompts_idx) == (borders['up'] + borders['down']) * 32 \
+ (borders['left'] + borders['right']) * (32 - borders['up'] - borders['down'])
assert image_prompt.image_prompts.shape[1] == 32 * 32
assert len(image_prompt.image_prompts_idx) == (borders['up'] + borders['down']) * 32 \
+ (borders['left'] + borders['right']) * (32 - borders['up'] - borders['down'])

Loading…
Cancel
Save